
[feat] Support user-defined data parser for SimpleStorage backend#82

Merged
ji-huazhong merged 5 commits into Ascend:main from 0oshowero0:exec_remote
Apr 23, 2026
Conversation

Collaborator

@0oshowero0 0oshowero0 commented Apr 21, 2026

Background

Users often need to store lightweight references (e.g., URLs or file paths) rather than the full data in TransferQueue, so that the expensive loading and decoding happen inside the storage backend instead of in user code, saving one data copy.


Solution

This PR introduces support for a user-defined data parser in the SimpleStorage backend.

The kv_put and kv_batch_put methods now accept an optional data_parser callable, which is executed inside each SimpleStorageUnit at put time. It receives the raw field_data dictionary from the put request and should return a dictionary with the same structure, with reference values replaced by the actual parsed data.
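As an illustration of that contract, a minimal parser might look like the sketch below. The `ref` field name and the `"HxW"` reference encoding are made up for the example; any other columns must pass through untouched.

```python
def minimal_parser(field_data: dict) -> dict:
    """Replace 'HxW' reference strings in the hypothetical `ref` column
    with parsed shape tuples; all other columns pass through unchanged.

    Contract: return a dict with the same keys and the same number of
    elements per column as the input.
    """
    if "ref" not in field_data:
        return field_data
    field_data["ref"] = [
        tuple(int(dim) for dim in ref.split("x")) for ref in field_data["ref"]
    ]
    return field_data
```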

Limitations & Future Work

  • Synchronous Execution: The data parser runs synchronously as part of the put operation, so a put request completes only after the parser finishes.
  • Backend Support: data_parser is currently only supported by the SimpleStorage backend.
  • Incorrect Metadata: Allowing user-provided functions to modify data in SimpleStorageUnit may produce shape & dtype metadata that no longer matches what is actually stored, since that metadata is captured while the data is still in the TransferQueueClient. This can break the RDMA transport, which relies on the metadata collected by TQ to restore tensors during get.
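Until storage units can report the post-parse schema back to the controller, a defensive wrapper can at least enforce the documented key/length contract at put time. This is a sketch, not part of TransferQueue, and it does not catch the shape/dtype drift that affects RDMA:

```python
def schema_checked(parser):
    """Wrap a data_parser so that key or per-column length violations fail
    fast inside the storage unit instead of surfacing later as metadata
    inconsistencies."""
    def wrapped(field_data: dict) -> dict:
        # Capture lengths before calling the parser, which may mutate in place.
        lengths = {name: len(col) for name, col in field_data.items()}
        out = parser(field_data)
        if set(out) != set(lengths):
            raise ValueError(
                f"data_parser changed keys: {sorted(out)} vs {sorted(lengths)}"
            )
        for name, n in lengths.items():
            if len(out[name]) != n:
                raise ValueError(
                    f"data_parser changed length of column {name!r}: "
                    f"{len(out[name])} vs {n}"
                )
        return out
    return wrapped
```

The wrapped closure still needs to survive cloudpickle serialization on its way to the storage unit, which plain closures like this one do.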

Demo Script

"""Demo: concurrent data_parser with separated single-sample logic.

This demo shows how to structure a data_parser so that:
1. The **core parser** only handles a **single sample**.
2. The **batch wrapper** uses asyncio to process all samples in parallel.
3. The wrapper is **synchronous to the outside**: it blocks until every
   sample finishes, so ``data_parser`` returning means data is ready.

Scenario:
- Users pass URL-like strings in a column.
- The parser sleeps 1 s per sample (simulating I/O / decode) and then
  creates a random tensor of the requested dtype & shape.
- Because the sleeps run concurrently via asyncio, a batch of N samples
  finishes in ~1 s instead of ~N s.
"""

import asyncio
import time

import ray
import torch
from tensordict import TensorDict, NonTensorStack

import transfer_queue as tq


# ---------------------------------------------------------------------------
# Core single-sample parser
# ---------------------------------------------------------------------------
def parse_url(url: str) -> torch.Tensor:
    """Parse a URL-like descriptor 'dtype:HxW' into a random tensor."""
    dtype_str, shape_str = url.split(":")
    dtype = getattr(torch, dtype_str)
    shape = [int(dim) for dim in shape_str.split("x")]
    return torch.randn(shape, dtype=dtype)


# ---------------------------------------------------------------------------
# Batch-level parser
# ---------------------------------------------------------------------------
def concurrent_batch_url_parser(field_data: dict) -> dict:
    """Batch-level data_parser executed inside SimpleStorageUnit.

    It receives a ``dict`` (not a TensorDict) where each value is a
    batched column.  For columns created from ``NonTensorStack`` the
    value is a plain ``list`` of Python objects.

    Workflow:
    1. Spawns one async task per list element.
    2. Waits until *all* tasks finish (``asyncio.gather``).
    3. Replaces the list with the list of results.

    Because ``asyncio.run`` blocks until the loop finishes, this function
    is **synchronous** to its caller: when it returns, every sample has
    been processed.

    Args:
        field_data: Mapping ``field_name -> batched_values``.  The dict
            keys must stay exactly the same; only values may be
            transformed in-place.

    Returns:
        The same dict with parsed values substituted.
    """
    if "data_to_be_parsed" not in field_data:
        return field_data

    urls: list[str] = field_data["data_to_be_parsed"]

    async def _async_parse_single(url: str) -> torch.Tensor:
        await asyncio.sleep(1.0)  # Add fixed delay per sample
        return parse_url(url)

    async def _process_all():
        tasks = [asyncio.create_task(_async_parse_single(url)) for url in urls]
        return await asyncio.gather(*tasks)

    start = time.perf_counter()
    field_data["data_to_be_parsed"] = asyncio.run(_process_all())
    elapsed = time.perf_counter() - start

    print(
        f"[data_parser] Processed {len(urls)} samples in {elapsed:.2f}s "
        f"(serial would be ~{len(urls)}.0s)"
    )
    return field_data


# ---------------------------------------------------------------------------
# Main demo flow
# ---------------------------------------------------------------------------
def main():
    ray.init(ignore_reinit_error=True)
    try:
        tq.init()

        batch_size = 32

        # Column that stays untouched
        normal_data = torch.randn(batch_size, 2)

        # Column to be parsed: URL-like strings describing dtype & shape.
        shapes = [(i % 4 + 1, i % 3 + 2) for i in range(batch_size)]
        urls = [f"float32:{h}x{w}" for h, w in shapes]
        data_to_be_parsed = NonTensorStack(*urls)

        data = TensorDict({
            "normal_data": normal_data,
            "data_to_be_parsed": data_to_be_parsed,
        }, batch_size=batch_size)

        keys = [f"sample_{i}" for i in range(batch_size)]

        # -------------------------------------------------------------------
        # Put with data_parser
        # -------------------------------------------------------------------
        put_start_time = time.perf_counter()
        meta = tq.kv_batch_put(
            keys=keys,
            partition_id="train",
            fields=data,
            data_parser=concurrent_batch_url_parser,
        )
        put_elapsed = time.perf_counter() - put_start_time
        print(f"Put succeeded. Fields: {meta.fields}")
        print(
            f"Total kv_batch_put time: {put_elapsed:.2f}s "
            f"(concurrency keeps it ~1s, not {batch_size}s)\n"
        )

        # -------------------------------------------------------------------
        # Fetch back and verify
        # -------------------------------------------------------------------
        result = tq.kv_batch_get(keys=keys, partition_id="train")

        # 1) normal_data unchanged
        torch.testing.assert_close(result["normal_data"], normal_data)
        print("[PASS] normal_data is unchanged.")

        # 2) Parsed tensors have correct dtype & shape
        expected_shapes = [(i % 4 + 1, i % 3 + 2) for i in range(batch_size)]
        for i, exp_shape in enumerate(expected_shapes):
            tensor = result["data_to_be_parsed"][i]
            assert tensor.dtype == torch.float32, (
                f"dtype mismatch at index {i}: expected torch.float32, got {tensor.dtype}"
            )
            assert tuple(tensor.shape) == exp_shape, (
                f"shape mismatch at index {i}: expected {exp_shape}, got {tuple(tensor.shape)}"
            )
        print(f"[PASS] All {batch_size} parsed tensors have correct dtype & shape.")

        # 3) Timing sanity check
        #    Serial execution would be ~batch_size seconds.
        #    Because asyncio tasks run in parallel, it should be ~1 s.
        #    We allow generous headroom for TQ network / serialization overhead.
        assert put_elapsed < 2.0, (
            f"Expected concurrent execution (~1s), but took {put_elapsed:.2f}s. "
            "Are the asyncio tasks actually running in parallel?"
        )
        print(f"[PASS] Timing looks concurrent: {put_elapsed:.2f}s < 2.0s")

        print("\n=== All verifications passed! ===")

        # wait for Ray log collect
        time.sleep(2)

    except Exception as e:
        print(f"Error: {type(e).__name__}: {e}")
        import traceback

        traceback.print_exc()
    finally:
        tq.close()
        ray.shutdown()


if __name__ == "__main__":
    main()

Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
@ascend-robot

CLA Signature Pass

0oshowero0, thanks for your pull request. All authors of the commits have signed the CLA. 👍

@0oshowero0 0oshowero0 marked this pull request as ready for review April 21, 2026 11:41
Copilot AI review requested due to automatic review settings April 21, 2026 11:41
Contributor

Copilot AI left a comment


Pull request overview

Adds an optional, user-provided data_parser hook to the KV put path so callers can store lightweight references and have them parsed/realized inside SimpleStorageUnit during put, reducing client-side decode/load work.

Changes:

  • Extend public KV APIs (kv_put, kv_batch_put, and async variants) to accept an optional data_parser.
  • Plumb data_parser through client → storage manager → ZMQ message → SimpleStorageUnit._handle_put, executing it before persisting.
  • Add a unit test for SimpleStorageUnit parsing behavior and update the tutorial notebook with a data_parser demo.

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
tutorial/basic.ipynb Documents and demonstrates data_parser usage in the basic tutorial.
transfer_queue/storage/simple_backend.py Executes data_parser inside SimpleStorageUnit before storing field data.
transfer_queue/storage/managers/simple_backend_manager.py Adds data_parser plumbing into PUT ZMQ requests to storage units.
transfer_queue/storage/managers/base.py Extends storage manager interface; explicitly rejects data_parser for KV-based backends.
transfer_queue/interface.py Exposes data_parser on high-level KV functions (sync + async).
transfer_queue/client.py Propagates data_parser through put/async_put to the storage manager.
tests/test_simple_storage_unit.py Adds a storage-unit-level test covering data_parser.
tests/test_client.py Updates mocked put_data signatures to accept the new parameter.
Comments suppressed due to low confidence (1)

transfer_queue/interface.py:470

  • data_parser is accepted even when fields is None, but in that branch the call is treated as a tag-only update and data_parser is silently ignored. This is likely a user error; consider raising ValueError when data_parser is provided without fields (same for the async variants) to avoid surprising no-ops.
    if fields is None and tag is None:
        raise ValueError("Please provide at least one parameter of `fields` or `tag`.")

    tq_client = _maybe_create_transferqueue_client()

    # 1. translate user-specified key to BatchMeta
    batch_meta = tq_client.kv_retrieve_meta(keys=[key], partition_id=partition_id, create=True)

    if batch_meta.size != 1:
        raise RuntimeError(f"Retrieved BatchMeta size {batch_meta.size} does not match with input `key` size of 1!")

    # 2. register the user-specified tag to BatchMeta
    if tag is not None:
        batch_meta.update_custom_meta([tag])

    # 3. put data
    if fields is not None:
        if isinstance(fields, dict):
            # TODO: consider whether to support this...
            batch = {}
            for field_name, value in fields.items():
                if isinstance(value, torch.Tensor):
                    if value.is_nested:
                        raise ValueError("Please use (async)kv_batch_put for batch operation")
                    batch[field_name] = value.unsqueeze(0)
                else:
                    batch[field_name] = NonTensorStack(value)
            fields = TensorDict(batch, batch_size=[1])
        elif not isinstance(fields, TensorDict):
            raise ValueError("`fields` can only be dict or TensorDict")

        # After put, batch_meta.field_names will include the new fields written by user
        batch_meta = tq_client.put(fields, batch_meta, data_parser=data_parser)
    else:
        # Directly update custom_meta (tag) to controller
        tq_client.set_custom_meta(batch_meta)
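The guard suggested in the comment above could be a small argument check at the top of kv_put. The following is a sketch against a simplified signature, not the actual TransferQueue code:

```python
def validate_kv_put_args(fields=None, tag=None, data_parser=None):
    """Reject argument combinations that would silently no-op (sketch)."""
    if fields is None and tag is None:
        raise ValueError("Please provide at least one parameter of `fields` or `tag`.")
    if data_parser is not None and fields is None:
        raise ValueError(
            "`data_parser` was provided without `fields`; a tag-only update "
            "would silently ignore it."
        )
```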



Comment thread transfer_queue/storage/managers/simple_backend_manager.py
Comment thread transfer_queue/storage/simple_backend.py
Comment thread tutorial/basic.ipynb Outdated
Comment thread tutorial/basic.ipynb
Comment on lines +920 to +937
parser_fields = TensorDict(
    {
        # Normal data column that will not be modified by data parser
        "normal_data": torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
        # Each sample carries a list of ints describing the desired tensor shape
        "data_to_be_parsed": NonTensorStack([2, 3], [1, 4], [3, 2]),
    },
    batch_size=3,
)

tq.kv_batch_put(
    keys=parser_keys,
    partition_id="train",
    fields=parser_fields,
    data_parser=create_data_by_shape_parser,
)
print("Stored 3 samples with data parser")

Copilot AI Apr 21, 2026


The data_parser demo uses shape descriptors [[2, 3], [1, 4], [3, 2]], which produces tensors varying in more than one dimension and triggers the warning about jagged nested tensors only supporting a single ragged dimension. To keep the tutorial output clean and demonstrate the intended path, consider using shapes that only vary along one ragged dimension (e.g., keep the trailing dims constant) or adjust the narrative to explain the warning/fallback.

Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 7 comments.



Comment thread tutorial/basic.ipynb Outdated
Comment thread transfer_queue/storage/simple_backend.py
Comment on lines 297 to 302
Args:
data: TensorDict containing the data to store.
metadata: BatchMeta containing storage location information.
data_parser: Optional callable to parse reference data (e.g., URLs) into real
content. Executed distributedly on each SimpleStorageUnit.
"""

Copilot AI Apr 21, 2026


Since data_parser runs inside the storage unit, it can change value types/shapes compared to the input TensorDict. However, SimpleStorageManager.put_data still notifies the controller using a schema derived from the original data (pre-parse), so BatchMeta.field_schema in the controller/clients can become inconsistent with what is actually stored. Consider constraining data_parser to preserve schema, or extending the PUT response so storage units can return the post-parse schema for accurate controller updates.

Comment thread transfer_queue/client.py
Comment thread transfer_queue/interface.py
Comment on lines +407 to 410
data_parser: Optional callable to parse reference data (e.g., URLs) into real
content. Receives a dict of field_name -> batched values and should
return a dict with the same structure. Only supported by SimpleStorage.


Copilot AI Apr 21, 2026


Allowing a user-provided data_parser to be serialized and executed inside storage units effectively enables arbitrary code execution in the storage process. If TransferQueue can be used in multi-tenant or untrusted-client scenarios, this is a significant security boundary change. Consider documenting this explicitly (trusted clients only) and/or gating it behind an opt-in config flag for SimpleStorage to avoid accidental exposure.

Comment thread tutorial/basic.ipynb Outdated
_, put_get_address = storage_setup
client = MockStorageClient(put_get_address)

def create_data_by_shape_parser(field_data):
Collaborator


We will need a clear input/output type definition of this function

return shapes, dtypes, custom_backend_meta_list

async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
async def put_data(
Collaborator


How to handle cases where the data_parser of each data item in the Tensordict is different?

Collaborator Author


  1. When the data type requiring specific processing can be detected from the pre-parsed data (e.g., via the URL, file extension, or data type metadata), we can implement conditional processing for the same field:
if "multi_modal_data" in field_data:
    parsed_url_data = []
    for pre_parser_data in field_data["multi_modal_data"]:
        if isinstance(pre_parser_data, str) and pre_parser_data.endswith(".mp4"):
            parsed_url_data.append(process_mp4_url(pre_parser_data))
        elif isinstance(pre_parser_data, str) and pre_parser_data.endswith((".jpg", ".png")):
            parsed_url_data.append(process_image_url(pre_parser_data))
        else:
            parsed_url_data.append(pre_parser_data)
    field_data["multi_modal_data"] = parsed_url_data
  2. If the data type is difficult to determine dynamically, we can simply separate the data that requires different parser functions into distinct fields.

Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 7 comments.



Comment thread tutorial/basic.ipynb
Comment thread transfer_queue/storage/simple_backend.py
Comment thread transfer_queue/storage/simple_backend.py
Comment thread transfer_queue/interface.py
Comment thread transfer_queue/interface.py Outdated
a list. It must return a dict of the same format with the exact
same keys and the same number of elements per column; do not
change the inner order of values within each column. Only
supported by SimpleStorage.

Copilot AI Apr 22, 2026


Since data_parser is serialized with cloudpickle and executed inside SimpleStorageUnit, it effectively allows arbitrary code execution on storage processes. The docstring should explicitly warn that this is only safe in trusted environments / when clients are authorized to run code on the backend.

Suggested change
supported by SimpleStorage.
supported by SimpleStorage. Security warning: this callable is
serialized and executed on storage/backend processes, so it
effectively permits backend code execution. Only use this in
trusted environments and only allow authorized clients to supply
`data_parser`.
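The opt-in gate the reviewer suggests could be sketched as follows; the config class and the `allow_data_parser` flag are hypothetical names, not existing TransferQueue configuration:

```python
from dataclasses import dataclass


@dataclass
class StorageUnitConfig:
    # Hypothetical flag: user parsers are rejected unless an operator opts in.
    allow_data_parser: bool = False


def maybe_run_parser(config: StorageUnitConfig, data_parser, field_data: dict) -> dict:
    """Run the user-supplied parser only when the storage unit explicitly allows it."""
    if data_parser is None:
        return field_data
    if not config.allow_data_parser:
        raise PermissionError(
            "data_parser is disabled on this storage unit; set "
            "allow_data_parser=True only when clients are trusted to run code here."
        )
    return data_parser(field_data)
```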

Comment thread tutorial/basic.ipynb Outdated
Comment thread tutorial/basic.ipynb Outdated
Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>

Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>

@0oshowero0 0oshowero0 changed the title [feat] Support user-defined data parser for SimpleStorage backend [BREAKING][feat] Support user-defined data parser for SimpleStorage backend Apr 22, 2026
@vermouth1992
Collaborator

Why this is a breaking change?

@ji-huazhong ji-huazhong changed the title [BREAKING][feat] Support user-defined data parser for SimpleStorage backend [feat] Support user-defined data parser for SimpleStorage backend Apr 23, 2026
@ji-huazhong ji-huazhong merged commit aec9192 into Ascend:main Apr 23, 2026
8 checks passed
