[FEA]: CTA-level load-balanced segmented scan for irregularly-sized segments #7741

@kierandidi

Description

Is this a duplicate?

Area

General CCCL

Is your feature request related to a problem? Please describe.

This feature request relates to tmol, a GPU-accelerated biomolecular energy function in PyTorch with custom C++/CUDA kernels. The CUDA code there currently depends on moderngpu, which is no longer actively maintained.

The most critical moderngpu dependency is a load-balanced segmented scan used for batched forward/inverse kinematics on protein backbone chains. Proteins are modeled as kinematic trees, and computing atom positions requires a segmented prefix scan with 4x4 matrix composition as the operator, where each segment is a protein chain of different length.

The problem is irregular parallelism: a batch of proteins contains segments of wildly different lengths (from ~10 atoms for a small peptide up to 1000+ for a large protein, with 50–200 segments per batch). Naive approaches (one CTA per segment) are extremely load-imbalanced. We need CTA-level building blocks for load-balanced segmented scans, which moderngpu provides but CCCL currently does not.

Describe the solution you'd like

A CCCL-native API for load-balanced segmented scan that supports:

  1. Merge-path based load balancing across CTAs — data elements AND segment boundaries are interleaved into a merged sequence so each CTA gets roughly nt * vt work items (threads per CTA × values per thread), regardless of segment sizes (equivalent to moderngpu's cta_load_balance_t).

  2. CTA-level segmented scan — within each CTA, threads process their vt work items sequentially, resetting the accumulator at segment boundaries detected via merge flags. A CTA-wide segmented scan then combines results across threads while respecting segment boundaries (equivalent to moderngpu's cta_segscan_t).

  3. Multi-CTA carry-out propagation (spine scan) — carry-out values from all CTAs are scanned recursively (still segment-aware), with a downward sweep propagating accumulated carry-in values back to each CTA's elements, stopping at the first segment boundary in each CTA.

  4. User-defined non-commutative types and operators — our scan operates on 4x4 homogeneous transformation matrices (16 floats) with matrix composition as the operator. This is associative but not commutative.

  5. Segment boundaries as sorted offset arrays (not boolean head flags).

  6. Both inclusive and exclusive scan modes.
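The load-balancing search in item 1 can be illustrated serially: for each CTA's starting diagonal in the conceptual merge of work-item ranks with segment offsets, a binary search finds how many items and how many boundaries precede the split. This is a sketch of the idea, not moderngpu's implementation:

```cpp
#include <algorithm>
#include <utility>
#include <vector>

// Merge-path style load-balancing search (serial illustration).
// Conceptually merge the work-item ranks 0..num_items-1 with the sorted
// segment offsets; boundary seg_offsets[j] precedes item i whenever
// seg_offsets[j] <= i. For diagonal d, return how many (items, boundaries)
// lie before the split, with items + boundaries == d.
std::pair<int, int> load_balance_split(int d, int num_items,
                                       const std::vector<int>& seg_offsets) {
    int lo = std::max(0, d - num_items);
    int hi = std::min(d, (int)seg_offsets.size());
    while (lo < hi) {
        int j = (lo + hi) / 2;  // candidate count of boundaries consumed
        int i = d - j - 1;      // rank of the last item before the split
        if (seg_offsets[j] <= i) lo = j + 1;
        else                     hi = j;
    }
    return {d - lo, lo};        // (work items, segment boundaries) before d
}
```

CTA k would take diagonals [k * nt * vt, (k + 1) * nt * vt), so every CTA receives the same number of merged work descriptors no matter how skewed the segment lengths are.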

The current implementation is ~400 lines built on moderngpu's CTA-level primitives: kernel_segscan.cuh. It uses:

  • cta_load_balance_t for merge-path work distribution (lines 296-308)
  • cta_segscan_t for CTA-level segmented scan (lines 79-86)
  • cta_reduce_t for finding the first segment boundary in a CTA (lines 74-75)
  • Recursive spine_segreduce for multi-CTA propagation (lines 132-217)

The call site for the kinematic tree scan is in compiled.cuda.cu, where it scans over 4x4 homogeneous transformation matrices with a custom composition operator.

Ideally, the CCCL equivalent would look something like:

// Hypothetical API
cub::DeviceSegmentedScan::InclusiveScan(
    d_temp_storage, temp_storage_bytes,
    d_in, d_out,
    num_items, num_segments, d_segment_offsets,
    compose_op, identity,
    stream,
    /* load_balanced = */ true  // use merge-path load balancing
);

Alternatively, expose composable CTA-level building blocks similar to moderngpu's approach; that would allow custom per-element functors during the scan (our implementation uses a user-supplied functor f(index, seg, rank) to generate input values on-the-fly rather than reading from a flat array).
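Concretely, the value type, operator, and on-the-fly generator could look like the following. All names here are hypothetical, sketched only to show the shapes such building blocks would need to accept:

```cpp
#include <array>

// 4x4 homogeneous transform: the value type being scanned (sketch).
struct HT { std::array<float, 16> m; };

// Associative but non-commutative composition operator: row-major
// matrix product a * b.
struct ComposeOp {
    HT operator()(const HT& a, const HT& b) const {
        HT c{};
        for (int r = 0; r < 4; ++r)
            for (int col = 0; col < 4; ++col)
                for (int k = 0; k < 4; ++k)
                    c.m[r * 4 + col] += a.m[r * 4 + k] * b.m[k * 4 + col];
        return c;
    }
};

// On-the-fly input generator in place of a flat input array; mirrors the
// f(index, seg, rank) functor mentioned above (payload is a stand-in).
struct MakeTransform {
    HT operator()(int /*index*/, int /*seg*/, int rank) const {
        HT t{};
        t.m[0] = t.m[5] = t.m[10] = t.m[15] = 1.0f;  // identity
        t.m[3] = (float)rank;                        // x-translation payload
        return t;
    }
};
```

A load-balanced CTA primitive that already knows each element's (index, seg, rank) triple could feed such a generator directly, which is exactly what the merge-path distribution produces as a byproduct.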

Describe alternatives you've considered

  • cub::DeviceSegmentedReduce: Only supports reduction, not scan (no prefix sums). Also doesn't do load balancing across segments.
  • cub::DeviceScan + manual segmentation: Requires pre-/post-processing to handle segment boundaries, losing the efficiency of the fused load-balanced approach.
  • cub::BlockScan with manual decomposition: Possible but would require reimplementing the merge-path load balancing, carry-out propagation, and spine scan — essentially rebuilding what moderngpu already provides.
  • Keep vendoring moderngpu: Works but moderngpu is unmaintained since 2021 and doesn't support newer architectures natively.

Additional context

  • The data type being scanned is a 16-float struct (4x4 matrix) with a non-commutative composition operator — not just int with plus.
  • Typical workload: O(100K) total elements across O(100) irregularly-sized segments.
  • The full codebase using this primitive is publicly available at github.com/uw-ipd/tmol.
