Device support for TA::Tensor via TA::UMTensor #554

Draft
ajay-mk wants to merge 21 commits into master from ajay/feature/umtensor

ajay-mk commented May 22, 2026

Summary

Adds device (CUDA/HIP) support to TA::Tensor through a new unified-memory tile type, TA::UMTensor<T>, backed by an Umpire-managed unified-memory allocator, with tile ops routed through the device task machinery.

Changes

  • Introduces TA::UMTensor<T>
  • is_device_tile<UMTensor<T>> specialization so the expression engine routes ops through madness::add_device_task.
  • to_device / to_host UM prefetch helpers
  • Overloads of tile-level operators (see src/TiledArray/device/tensor.h)
  • Array-level helpers for host-to-device conversions.
  • Tests: tests/expressions_device_tensor.cpp
  • Examples: examples/device/ta_dense_um_tensor.cpp, ta_vector_um_tensor.cpp

Changes are ported over from the experimental #531 with help from Claude Code. The current version is tested with MPQC using UMTensor as the device tile type.

ajay-mk added 21 commits May 12, 2026 21:52
Adds the entry point for native TA::Tensor support on devices, paralleling
the existing btas_um_tensor path. UMTensor<T> is a TA::Tensor instantiated
on device_um_allocator<T>, which already maps to MemorySpace::Device_UM
via platform.h:90.

* fwd.h: UMTensor<T> typedef under TILEDARRAY_HAS_DEVICE.
* device/tensor.h: is_device_tile<UMTensor<T>> partial specialization so
  the existing pass-through specs for Tile<T> and LazyArrayTile<T, Op>
  (tensor/type_traits.h) classify UMTensor-backed tiles as device tiles
  -- this is the gate that routes the expression engine through
  madness::add_device_task at binary_eval/unary_eval/contraction_eval and
  the 6 sites in expressions/expr.h. Also adds detail::to_device /
  detail::to_host prefetch helpers that go directly through
  device::memPrefetchAsync (TA::Tensor stores data via shared_ptr<T[]>
  rather than a varray, so we cannot route through to_execution_space).
* device/tensor.cpp: static_asserts pinning down that the trait fires
  for UMTensor<{double,float,complex<double>}>, propagates through
  Tile<>, and does not misclassify plain Tensor<double>. Placeholder
  for the explicit instantiations that will land in Phase 4.
* src/CMakeLists.txt: hook the two files into the TILEDARRAY_HAS_HIP OR
  TILEDARRAY_HAS_CUDA source list, alongside btas_um_tensor.{h,cpp}.

No tile-op overloads yet -- Phase 1 is compile-only.
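
A minimal sketch of the trait gate described above, using hypothetical stand-in types (`host_allocator`, `um_allocator`, and the empty `Tensor`/`Tile` templates are illustrative and are not the real TiledArray declarations). It shows a partial specialization on the allocator, a pass-through specialization for a wrapper tile, and the kind of static_asserts that pin the classification down at compile time.

```cpp
#include <complex>
#include <type_traits>

// Hypothetical stand-ins for the allocator and tile types.
template <typename T> struct host_allocator {};
template <typename T> struct um_allocator {};

template <typename T, typename Alloc = host_allocator<T>>
struct Tensor {};

// A UM tensor is the same Tensor template instantiated on the UM allocator.
template <typename T>
using UMTensor = Tensor<T, um_allocator<T>>;

// Hypothetical wrapper tile, standing in for a Tile<T>-like adaptor.
template <typename T>
struct Tile {};

// Primary template: by default a tile is not a device tile.
template <typename T>
struct is_device_tile : std::false_type {};

// Partial specialization: a Tensor backed by the UM allocator is a device tile.
template <typename T>
struct is_device_tile<Tensor<T, um_allocator<T>>> : std::true_type {};

// Pass-through: a wrapper inherits the classification of the tile it wraps.
template <typename T>
struct is_device_tile<Tile<T>> : is_device_tile<T> {};

template <typename T>
inline constexpr bool is_device_tile_v = is_device_tile<T>::value;

// Compile-time checks in the spirit of the static_asserts in the commit.
static_assert(is_device_tile_v<UMTensor<double>>);
static_assert(is_device_tile_v<UMTensor<float>>);
static_assert(is_device_tile_v<UMTensor<std::complex<double>>>);
static_assert(is_device_tile_v<Tile<UMTensor<double>>>);
static_assert(!is_device_tile_v<Tensor<double>>);
```

Because the checks are static_asserts, a misclassification breaks the build rather than surfacing later as a wrong dispatch at run time.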

Adds the smallest tile-op set that the expression engine needs in order
to evaluate a `C("ij") = a*A("ik")*B("kj")`-style expression plus a norm
on UMTensor: clone, scale/scale_to, neg/neg_to, add/add_to (+ scaled
forms), subt/subt_to (+ scaled forms), dot, squared_norm, norm, and gemm
(returning + accumulating). Element-wise mult, permute, shift, and
batched paths are deliberately not included yet -- they need librett /
custom kernels and a clear nbatch story, and dragging them in here
would obscure whether the dispatch surface alone is correct.

Each overload is a concrete-type free function in `namespace TiledArray`
so ADL prefers it over the generic templated forwarders in
`tile_op/tile_interface.h`. No constraint relaxation, no member-function
revert -- the dispatch falls out of overload partial ordering.

The kernel pattern mirrors `device/btas.h`: resolve a BLAS++ queue from
`blasqueue_for(range)`, prefetch every input + result to the device,
call into BLAS++, then `sync_madness_task_with(stream)` so the wrapping
MADNESS device task waits for the queue to drain. We do NOT thread an
explicit `blas::Queue&` through composite ops -- `stream_for(range)` in
`external/device.h` already returns the current task's stream when
invoked inside a device task, which is what we want.

For now batched tiles (`nbatch_ > 1`) are asserted away; the expression
engine doesn't currently route batched UMTensor through these ops and
silently miscomputing would be worse than a clear assert.

`device/tensor.cpp` grows a tiny instantiation probe that exercises the
full Tier-1 surface for `double` and `float`, so anything that doesn't
type-check breaks the build immediately rather than waiting on Phase 5
tests. Real explicit instantiations land in Phase 4.

Rounds out Tier 1 with the ops that require librett or the device
element-wise kernel:

* permute(arg, Permutation)         -- librett_permute on the tile data
* permute(arg, BipartitePermutation)-- forwards to outer(perm); required
                                       to win ADL against the generic
                                       member-delegating overload (the
                                       analogous comment lives in
                                       device/btas_um_tensor.h:193)
* shift(arg, bound_shift)           -- copy + Range::inplace_shift
* shift_to(arg, bound_shift)        -- in-place range shift, no copy
* mult / mult_to (+ scaled + permuted variants) via device::mult_kernel
  and device::mult_to_kernel
* scale(a, f, perm), neg(a, perm), add(..., perm), subt(..., perm), and
  their scaled-with-perm forms -- thin compositions over the
  non-permuted core and the new permute

All sit in `namespace TiledArray`; ADL wins over the
`tile_op/tile_interface.h` defaults the same way it does for the
non-permuted core. The instantiation probe in device/tensor.cpp now
exercises the full Tier-1 surface (Phase 2a + 2b).

…ssion tests

Three bug categories surfaced once expressions actually ran end-to-end:

1. **Scale-factor semantics on `*_to(arg, factor)` were inverted.**
   `TA::Tensor::add_to(r, factor)` is `(this += r) *= factor`, and
   `TA::Tensor::subt_to(r, factor)` is `(this -= r) *= factor` -- the
   scaling applies to the result of the in-place op, not to `arg` alone.
   My initial implementations computed `result += factor * arg` and
   `result -= factor * arg`, which gives wrong values whenever the
   engine calls `subt_to(std::move(second), first, -1)` from
   `tile_op/subt.h:117` (the "right is consumable" branch -- expected
   to compute `(second - first) * -1 = first - second`). The new
   implementations chain `add_to/subt_to(no factor)` + `scale_to`,
   matching TA::Tensor's `(l += r) *= factor` pattern.

2. **Engine passes consumable operands as rvalues.** Subt::eval and
   ScalSubt::eval at `tile_op/subt.h:117,302` pass the result tile via
   `std::move(...)`. A plain `UMTensor<T>&` overload is not a viable
   candidate for an rvalue argument, so overload resolution falls
   through to the generic forwarder in `tile_op/tile_interface.h`
   (and `tile_interface/scale.h`) that delegates to TA::Tensor's CPU
   member function. The CPU member reads UM memory while the prior
   device kernel is still in flight on the queue -- silently
   miscomputing. Every in-place op (scale_to, neg_to, add_to[x2],
   subt_to[x2], mult_to[x2], shift_to) now has two concrete-type
   overloads: `UMTensor<T>&` and `UMTensor<T>&&`. Concrete types beat
   the templated forwarding reference in partial ordering regardless
   of constraint shape -- a constrained forwarding-ref version would
   in principle also win (constrained > unconstrained per [temp.constr.
   order]) but g++ does not consistently treat tile_interface's
   `enable_if`-only forwarders as unconstrained for this purpose; the
   two-concrete-overload form is robust.

3. **No correctness coverage existed for the dispatch path.** The
   previous `tensor_device.cpp` exercised tile-level ops directly from
   the main thread, which is the wrong contract: tile ops cooperate
   with the enclosing `madness::add_device_task` for stream sync and
   never run that way in production. Replaced with
   `tests/expressions_device_tensor.cpp`, mirroring the structure of
   `expressions_device_um.cpp` (the existing btas device test). 18
   cases including: trait classification, direct assign, permute,
   scale, neg, add/subt (+ with-permute / -to / -with-factor variants),
   scaled-subt isolations on left/right, mixed linear combination
   (catches bug #1), Hadamard, contraction, norm2 / dot reductions,
   and a `reuse_stress` case repeating `dot(a, a)` 8x (catches the
   LazyArrayTile conversion race in MPQC's pattern -- expected to be
   a master-branch baseline failure not introduced here). All 18 pass.
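
The difference in item 1 can be reproduced on plain vectors. The sketch below is illustrative only: `Vec`, `subt_to_chained`, and `subt_to_inverted` are hypothetical names, and host loops stand in for the device kernels. It contrasts the `(result -= arg) *= factor` semantics with the inverted `result -= factor * arg` form.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;  // hypothetical stand-in for a tile's data

// Semantics described in item 1: (result -= arg) *= factor.
inline Vec& subt_to_chained(Vec& result, const Vec& arg, double factor) {
  assert(result.size() == arg.size());
  for (std::size_t i = 0; i < result.size(); ++i) result[i] -= arg[i];  // subt_to, no factor
  for (double& x : result) x *= factor;                                 // scale_to
  return result;
}

// The inverted form from the initial implementation: result -= factor * arg.
inline Vec& subt_to_inverted(Vec& result, const Vec& arg, double factor) {
  assert(result.size() == arg.size());
  for (std::size_t i = 0; i < result.size(); ++i) result[i] -= factor * arg[i];
  return result;
}
```

With `second = {5, 7}`, `first = {1, 2}`, and a factor of -1, the chained form yields `first - second = {-4, -5}`, while the inverted form yields `second + first = {6, 9}`.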

Also documents the `UMTensorArg` concept inline as the marker for
"this is a UMTensor (any cv/ref qualifier)" -- kept around as
documentation of intent even where the dispatch tiebreak forced us
to use concrete-type overloads instead.
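
The rvalue fall-through from item 2 can be shown in isolation. This is a sketch with hypothetical names (`DeviceTile`, `Path`, and the two namespaces are illustrative, not TiledArray code). It relies on two standard rules: an lvalue-reference parameter cannot bind an rvalue argument, and when conversion sequences are otherwise equal a non-template function is preferred over a function template specialization.

```cpp
#include <cassert>
#include <utility>

struct DeviceTile {};  // hypothetical tile type

enum class Path { Generic, Device };

namespace lvalue_only {
// Generic forwarder taking a forwarding reference.
template <typename Tile>
Path scale_to(Tile&&, double) { return Path::Generic; }
// Only an lvalue-reference overload for the concrete type.
inline Path scale_to(DeviceTile&, double) { return Path::Device; }
}  // namespace lvalue_only

namespace both {
template <typename Tile>
Path scale_to(Tile&&, double) { return Path::Generic; }
// Two concrete-type overloads: lvalue and rvalue.
inline Path scale_to(DeviceTile&, double) { return Path::Device; }
inline Path scale_to(DeviceTile&&, double) { return Path::Device; }
}  // namespace both

inline void dispatch_demo() {
  DeviceTile t;
  // An lvalue argument reaches the concrete overload in both cases.
  assert(lvalue_only::scale_to(t, 2.0) == Path::Device);
  assert(both::scale_to(t, 2.0) == Path::Device);
  // An rvalue argument cannot bind DeviceTile&, so the template is chosen.
  assert(lvalue_only::scale_to(std::move(t), 2.0) == Path::Generic);
  // With the DeviceTile&& overload present, the concrete path is kept.
  assert(both::scale_to(std::move(t), 2.0) == Path::Device);
}
```

The `lvalue_only` namespace reproduces the silent fall-through to the generic path, and `both` shows why each in-place op carries a pair of overloads.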

…m, contraction variants)

Goes from 18 to 59 test cases, covering the patterns the expression
engine uses in production but the original cut didn't exercise. New
coverage:

Elementary ops + permutations:
* In-place expression operators (+=, -=, *=, += with permute).
* Negation in compounds: -(2*(a+b)), -a("c,b,a").
* All scale/permute combinations for add/subt/mult: scale_add,
  scale_add_permute, subt_permute, scale_subt, scale_subt_permute,
  mult_permute, scale_mult, scale_mult_permute.

Dataflow + reduction:
* Multi-step dataflow chain (t = a + b; c = 2*t - a) -- engine wires
  t's Future<Tile> into the next dist-eval without an intervening
  fence (per CLAUDE.md's synchronization hierarchy).
* Contraction-plus-reduce (norm2 of a contraction result).
* no_alias() + reduce -- exercises the LHS-doesn't-alias-RHS
  optimization through to a reduction.

Block expressions (PR 531 trouble area):
* Basic block assign / scaled-sum / accumulate.
* const_block: block from a const reference.
* scal_block: 2 * a.block(...).
* permute_block: a("c,b,a").block(...).
* assign_sub_block: write into a tile sub-block of an existing array.
* block_contract, block_permute_contract: blocks fed into GEMM.

Contraction variants:
* Outer product (rank-changing GEMM, no shared contraction index).
* Permuted result (c("k,i") = a("i,j") * b("j,k")).
* Transpose-on-right input (c("i,k") = a("i,j") * b("k,j")).
* CC-style rank-4: r("a,c") = t("a,b,k,l") * v("c,b,k,l").
* scale_cont, scale_cont_permute, scale_cont_with_input_transpose
  -- scale-fuse-into-GEMM paths.
* cont_non_uniform_split_inner / _two_inner -- catches GEMM kernels
  that silently assume uniform tile shapes.

TA::einsum entry point:
* Matmul, Hadamard, two-index contraction. Documents the one pattern
  not covered (`einsum("ij,jk->ijk")` with an index in both inputs
  and output) -- it segfaults inside einsum's internals on master
  regardless of allocator, out of scope for this branch.

Dot variants:
* dot_permute: dot of permuted arrays.
* dot_contr: dot of two contraction expressions (one tier deeper than
  btas-device's NO_THROW-only version; we validate the scalar value).

Tolerances are 5e-14 for non-GEMM ops, 1e-10 to 1e-9 for GEMM-bearing
paths to absorb summation-order differences between BLAS++ on the
device and the Eigen-based CPU reference.
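
To illustrate why GEMM-bearing paths need a looser tolerance, the sketch below shows that floating-point addition is not associative, so two correct summation orders can disagree in the last bits. The helper `close_to` and its max(1, |reference|) scaling are assumptions made for the example; the commit does not say whether the test tolerances are absolute or relative.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Same three terms, two association orders.
inline double sum_left(double a, double b, double c) { return (a + b) + c; }
inline double sum_right(double a, double b, double c) { return a + (b + c); }

// Hypothetical comparison helper: tolerance scaled by max(1, |reference|).
inline bool close_to(double value, double reference, double tol) {
  assert(tol >= 0.0);
  return std::abs(value - reference) <= tol * std::max(1.0, std::abs(reference));
}
```

For 0.1, 0.2, and 0.3 the two orders differ by about 1e-16 in IEEE double precision, which an exact comparison rejects and a 5e-14 tolerance accepts. Longer reductions such as a GEMM inner product accumulate more of these differences, which is consistent with the wider 1e-10 to 1e-9 band quoted above.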

Validation:
* All 59 device-tensor cases pass.
* Full np=1 ta_test suite (1880 cases, 12.56M assertions) still
  green -- no regressions from this branch.

Phase 3 surface: serialization plus the bulk DistArray-level helpers
(prefetch and host<->device conversion) needed by downstream code that
wants to round-trip UMTensor data through MADNESS archives or shuttle
DistArrays between memory spaces.

src/TiledArray/device/tensor.h:

* `to_host(DistArray<UMTensor<T>, P>&)` and
  `to_device(DistArray<UMTensor<T>, P>&)`: bulk prefetch every local
  tile, fence the world, deviceSynchronize on exit. Mirrors the btas
  helpers in btas_um_tensor.h:567-617 but takes the bare UMTensor
  (no `TA::Tile<>` wrapper, per CLAUDE.md).
* `um_tensor_to_ta_tensor(DistArray<UMTensor<T>, P>)` and
  `ta_tensor_to_um_tensor(DistArray<TA::Tensor<T>, P>)`: tile-by-tile
  conversion via `to_new_tile_type`. The per-tile lambda allocates a
  result of the target tile type, prefetches the source as needed,
  and memcpys -- since both sides are TA::Tensor and only the
  allocator differs, no per-element conversion is required.
* `madness::archive::ArchiveStoreImpl<Archive, UMTensor<T>>`:
  prefetches the tile to host before serializing, then writes the
  same fields TA::Tensor::serialize would
  (empty/range/nbatch/wrap(data)). The default member serialize is
  not safe to use as-is because UM data may be stale on the host
  while a device kernel is in flight.
* `madness::archive::ArchiveLoadImpl<Archive, UMTensor<T>>`:
  reconstructs the tensor in UM (writes go through host pages of UM
  -- if downstream code wants the data on the device, it should call
  `to_device` explicitly).
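
The "only the allocator differs" conversion can be sketched with standard containers. Everything here is a stand-in: `tagged_allocator` is a hypothetical allocator that plays the role of the UM allocator, `convert_storage` is an illustrative name, and the prefetch and stream synchronization that the real helpers perform are omitted.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <memory>
#include <new>
#include <type_traits>
#include <vector>

// Hypothetical allocator standing in for a unified-memory allocator.
template <typename T>
struct tagged_allocator {
  using value_type = T;
  tagged_allocator() = default;
  template <typename U>
  tagged_allocator(const tagged_allocator<U>&) noexcept {}
  T* allocate(std::size_t n) {
    return static_cast<T*>(::operator new(n * sizeof(T)));
  }
  void deallocate(T* p, std::size_t) noexcept { ::operator delete(p); }
};
template <typename T, typename U>
bool operator==(const tagged_allocator<T>&, const tagged_allocator<U>&) noexcept { return true; }
template <typename T, typename U>
bool operator!=(const tagged_allocator<T>&, const tagged_allocator<U>&) noexcept { return false; }

// Same element type on both sides, so a raw byte copy is enough.
template <typename DstAlloc, typename T, typename SrcAlloc>
std::vector<T, DstAlloc> convert_storage(const std::vector<T, SrcAlloc>& src) {
  static_assert(std::is_trivially_copyable_v<T>, "memcpy needs trivially copyable elements");
  std::vector<T, DstAlloc> dst(src.size());  // allocate the target storage
  if (!src.empty()) std::memcpy(dst.data(), src.data(), src.size() * sizeof(T));
  assert(dst.size() == src.size());
  return dst;
}
```

A round trip such as `convert_storage<tagged_allocator<double>>(convert_storage<std::allocator<double>>(um))` preserves every value, which is the property the `um_to_ta_round_trip` test checks at the array level.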

tests/expressions_device_tensor.cpp:

Five new cases (now 64 total):
* serialize_um_tensor: single-tile round-trip through BufferOutput/
  Input archives.
* serialize_um_tensor_empty: empty branch in Store/Load.
* um_to_ta_round_trip: device array -> host array -> device array,
  values preserved across both legs.
* um_to_ta_then_expression: a host expression on a converted-from-
  device array matches the same expression on the host mirror.
* bulk_prefetch_round_trip: to_host/to_device on a DistArray are
  no-ops for correctness (they only adjust page residency hints).

Validation:
* All 64 device-tensor cases pass.
* Full np=1 ta_test suite: 1885 cases, 12.64M assertions -- still
  green; no regressions from this branch.

Phase 4: instantiate `Tensor<T, device_um_allocator<T>>` once in
device/tensor.cpp and `extern template`-declare them in
device/tensor.h, so each TU including device/tensor.h does not
re-instantiate the full ~3000-line Tensor class body.

Mirrors the pattern at the bottom of src/TiledArray/tensor/tensor.h
+ src/TiledArray/tensor/tensor.cpp for the host-side instantiations.
The instantiated set is double, float, complex<{double,float}>, int,
long -- a superset of the host set (which omits int/long) for parity
with btas_um_tensor.cpp.
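
For readers unfamiliar with the pattern, here is a single-file sketch of `extern template` with a hypothetical `Buffer<T>` class. In a real project the `extern template` declarations would live in the header and the `template class` definitions in exactly one .cpp file. Placing both in one translation unit is legal as long as the declaration comes first, which keeps the example compilable on its own.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical class template standing in for a large tensor class.
template <typename T>
class Buffer {
 public:
  explicit Buffer(std::size_t n, T value = T{}) : data_(n, value) {}
  std::size_t size() const { return data_.size(); }
  const T& at(std::size_t i) const {
    assert(i < data_.size());
    return data_[i];
  }
  T sum() const {
    T s{};
    for (const T& x : data_) s += x;
    return s;
  }

 private:
  std::vector<T> data_;
};

// "Header" part: tell every including TU not to instantiate these implicitly.
extern template class Buffer<double>;
extern template class Buffer<float>;

// ".cpp" part: the single explicit instantiation definition of each.
template class Buffer<double>;
template class Buffer<float>;
```

Each translation unit that sees the `extern template` lines links against the one instantiation instead of generating its own copy of every member.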

BLAS-bearing free functions (gemm, scale, axpy-driven add/subt, ...)
are left as header-defined templates; explicitly instantiating them
would pull the full BLAS++/librett surface into device/tensor.cpp,
and the build-time saving from extern-templating them does not
justify it. They get instantiated lazily in whichever TU actually
calls them (typically the test or example TU).

Validation:
* All 64 device-tensor cases still pass.
* Full np=1 ta_test suite: 1885 cases, 12.63M assertions -- still
  green; no regressions from this change.

Phase 6: two example programs that exercise the UMTensor surface
end-to-end through real timing loops, mirroring the existing
btas-based ta_dense_device.cpp and ta_vector_device.cpp but using
the bare UMTensor tile type (no `TA::Tile<>` wrapper).

* examples/device/ta_dense_um_tensor.cpp:
    c(Nm,Nn) = a(Nm,Nk) * b(Nk,Nn) on UMTensor<double> tiles, blocked
    by user-supplied Bm/Bn/Bk. Reports per-iteration wall time and
    GFLOPS, then verifies every result element equals the analytic
    Nk * val_a * val_b. Honors `cudaProfilerStart/Stop` when built
    with CUDA so the timed loop is profilable separately from setup.

* examples/device/ta_vector_um_tensor.cpp:
    Element-wise op benchmark -- add, subt, scale, Hadamard, permute,
    in-place axpy -- on Nm x Nn UMTensor matrices. Reports per-op
    average wall time and effective bandwidth (counting one read +
    one write per element for unary ops, two reads + one write for
    binary ops).
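
The reported figures can be derived as in the sketch below. The function names are hypothetical and the code is not taken from the examples. The 2 * Nm * Nn * Nk flop count for a dense matrix product is the usual convention and is an assumption here, while the read/write counting follows the description above.

```cpp
#include <cassert>
#include <cstddef>

// Assumed convention: a dense Nm x Nk times Nk x Nn product costs 2*Nm*Nn*Nk flops.
inline double gemm_gflops(std::size_t Nm, std::size_t Nn, std::size_t Nk, double seconds) {
  assert(seconds > 0.0);
  const double flops = 2.0 * static_cast<double>(Nm) * static_cast<double>(Nn) *
                       static_cast<double>(Nk);
  return flops / seconds / 1.0e9;
}

// Effective bandwidth in GB/s: (reads + writes) elements moved per output element.
// Unary ops use reads = 1, writes = 1; binary ops use reads = 2, writes = 1.
inline double effective_bandwidth_gbs(std::size_t n_elements, std::size_t element_bytes,
                                      int reads, int writes, double seconds) {
  assert(seconds > 0.0);
  const double bytes = static_cast<double>(n_elements) *
                       static_cast<double>(element_bytes) * (reads + writes);
  return bytes / seconds / 1.0e9;
}

// Analytic value of each result element when a and b are filled with constants.
inline double expected_dense_element(std::size_t Nk, double val_a, double val_b) {
  return static_cast<double>(Nk) * val_a * val_b;
}
```

For example, a 256 x 256 x 256 product that takes one second corresponds to about 0.0336 GFLOPS, and a binary op on a 512 x 512 matrix of doubles that takes one second moves 6,291,456 bytes, or about 0.0063 GB/s.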

Both examples follow the convention in CLAUDE.md / the rest of TA's
device examples: TA_SCOPED_INITIALIZE for runtime setup, world fence
before reading device-side results, exception-catching `main` wrapper.
Hooked into examples/device/CMakeLists.txt's `foreach(_exec ...)`
list so they build alongside the existing examples whenever
TILEDARRAY_HAS_CUDA OR TILEDARRAY_HAS_HIP.

Smoke-validated locally:
  ta_dense_um_tensor 256 64 256 64 256 64 3
    -> Verification PASSED
  ta_vector_um_tensor 512 128 512 128 3
    -> all six ops reported timings without error

The runtime instantiation probe (`compile_test_tier1<T>` plus the two
`instantiate_tier1_*` function pointers) was added in the Phase 2 commits
to force template type-checking of the device tile-op overloads before
any test exercised them. It is no longer load-bearing:

* tests/expressions_device_tensor.cpp (64 cases) calls every tier-1
  overload through the expression engine for `double`, so any
  instantiation breakage shows up at test-build time.
* device/tensor.cpp's explicit `template class Tensor<T,
  device_um_allocator<T>>` instantiations cover the class members
  authoritatively.

Removing the probe trims ~75 lines of dead code from the .cpp. The
`is_device_tile_v` static_asserts are kept -- they are zero-cost and
guarantee trait correctness even when BUILD_TESTING=OFF.

* Remove the `UMTensorArg` concept and the `detail::is_um_tensor` trait
  it was built on. Both were introduced during the dispatch debugging
  to try a constrained forwarding-reference approach (per the comment
  block at the top of the section); that approach was abandoned in
  favor of two concrete-type overloads (`UMTensor<T>&` and
  `UMTensor<T>&&`) per in-place op, which beat the templated
  forwarding-ref candidate in partial ordering. The concept and trait
  are dead declarations; the explanatory comment is kept and trimmed.

* Rewording: two inline comments referencing "Phase 2" and "Phase 3"
  development sequencing -- meaningless to a future reader -- are
  reworded to describe the rule without the historical label.

Verified: 64-case device_tensor_expressions_suite still passes; no
behavior change.

Widen um_tensor_to_ta_tensor / ta_tensor_to_um_tensor from <T, Policy> to
<UMTile, HostTile, Policy>, mirroring the signature in
device/btas_um_tensor.h:619+. Identity overloads cover the case where source
and destination tiles coincide.

The tile-op overloads for UMTensor<T> dispatch BLAS-on-pointer kernels
(scal, copy, gemm, dot, etc.). When CCk is instantiated for the device
tile, it composes nested tiles of the form UMTensor<UMTensor<double>>
to represent its outer/inner structure; those nested tiles must route
through TA::Tensor member ops, not the device kernels.

Add `requires TiledArray::detail::is_numeric_v<T>` to each device tile-op
overload, and gate `is_device_tile<UMTensor<T>>` and the archive store
helper on the same predicate so non-numeric element types fall through
to the host tile path.
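
A sketch of the gating idea, with hypothetical stand-ins throughout: the local `is_numeric` trait, the empty `Tensor`, and the `Path` enum are illustrative. The commit uses a C++20 `requires` clause. The sketch uses `std::enable_if_t` as a C++17 equivalent so that it compiles without concepts, and the overload-selection outcome is the same.

```cpp
#include <cassert>
#include <complex>
#include <type_traits>

// Hypothetical numeric predicate: arithmetic types and complex of floating point.
template <typename T>
struct is_numeric : std::is_arithmetic<T> {};
template <typename T>
struct is_numeric<std::complex<T>> : std::is_floating_point<T> {};
template <typename T>
inline constexpr bool is_numeric_v = is_numeric<T>::value;

template <typename T>
struct Tensor {};  // hypothetical tile; T may itself be a Tensor (nested tile)

enum class Path { Host, Device };

// Generic fallback, standing in for the host tile path.
template <typename Tile>
Path scale(const Tile&, double) { return Path::Host; }

// Device overload, enabled only when the element type is numeric.
template <typename T, std::enable_if_t<is_numeric_v<T>, int> = 0>
Path scale(const Tensor<T>&, double) { return Path::Device; }

inline void gating_demo() {
  Tensor<double> flat;
  Tensor<Tensor<double>> nested;
  assert(scale(flat, 2.0) == Path::Device);  // numeric elements: device path
  assert(scale(nested, 2.0) == Path::Host);  // tensor-of-tensors: falls through
}
```

For `Tensor<double>` both overloads are viable and the more specialized device overload is selected. For the nested tile the device overload drops out of the candidate set, leaving the generic path.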

The returning gemm overload was missing the left_right_congruent check, and
the in-place overload was missing the full left_result/right_result/left_right
congruence set. Geometry mismatches would have silently produced wrong
results instead of asserting. Mirrors the asserts in tensor/kernels.h's
host-side detail::gemm worker.
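
The three congruence checks can be sketched for the rank-2 case. The `Matrix` struct and the free functions below are hypothetical. They borrow the check names from the commit text but are not the TiledArray helper, which works on general ranges and is not reproduced here.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical row-major matrix standing in for a tile.
struct Matrix {
  std::size_t rows, cols;
  std::vector<double> data;
  Matrix(std::size_t r, std::size_t c, double v = 0.0) : rows(r), cols(c), data(r * c, v) {}
};

// Inner (contracted) extents of left and right must agree.
inline bool left_right_congruent(const Matrix& l, const Matrix& r) { return l.cols == r.rows; }
// Outer extent of left must match the result's rows.
inline bool left_result_congruent(const Matrix& l, const Matrix& c) { return l.rows == c.rows; }
// Outer extent of right must match the result's columns.
inline bool right_result_congruent(const Matrix& r, const Matrix& c) { return r.cols == c.cols; }

// In-place (accumulating) form: c += alpha * a * b, guarded by all three checks.
inline void gemm_accumulate(Matrix& c, const Matrix& a, const Matrix& b, double alpha) {
  assert(left_right_congruent(a, b));
  assert(left_result_congruent(a, c));
  assert(right_result_congruent(b, c));
  for (std::size_t i = 0; i < a.rows; ++i)
    for (std::size_t j = 0; j < b.cols; ++j) {
      double acc = 0.0;
      for (std::size_t k = 0; k < a.cols; ++k)
        acc += a.data[i * a.cols + k] * b.data[k * b.cols + j];
      c.data[i * c.cols + j] += alpha * acc;
    }
}
```

Without the result-side checks, a caller that passes a mis-shaped `c` would read or write the wrong elements with no diagnostic, which is the failure mode described above.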

detail::to_host wraps memPrefetchAsync, so the immediate host reads in
um_tensor_to_ta_tensor's convert_tile lambda and in ArchiveStoreImpl::store
were racing the prefetch on Pascal+ devices with concurrent_managed_access.
Drain the stream via sync_madness_task_with before the host walks the tile.

- shift_to: call Tensor::shift_to instead of const_cast'ing the range.
  TA::Tensor exposes a public shift_to member (unlike btas::Tensor),
  so the const_cast inherited from btas_um_tensor.h is unnecessary here.
- apply_scale_factor: flatten 3-level nested if constexpr into one
  else-if-constexpr cascade.