From d19df41498955ae7602ee29a9dc0d68d2867a803 Mon Sep 17 00:00:00 2001 From: Bimal Gaudel Date: Mon, 8 Jul 2024 07:41:21 -0400 Subject: [PATCH 1/3] Revert "ToT support for `math/linalg` functions and `concat` function." This reverts commit 52b09600924d68092d85c766daa3455be74040ba. --- src/TiledArray/conversions/concat.h | 5 ++--- src/TiledArray/math/linalg/basic.h | 7 +++---- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/TiledArray/conversions/concat.h b/src/TiledArray/conversions/concat.h index 398a5dc7b3..7c440c54e2 100644 --- a/src/TiledArray/conversions/concat.h +++ b/src/TiledArray/conversions/concat.h @@ -92,9 +92,8 @@ DistArray concat( DistArray result(*target_world, tr); const auto annot = detail::dummy_annotation(r); for (auto i = 0ul; i != arrays.size(); ++i) { - result.make_tsrexpr(annot).block(tile_begin_end[i].first, - tile_begin_end[i].second) = - arrays[i].make_tsrexpr(annot); + result(annot).block(tile_begin_end[i].first, tile_begin_end[i].second) = + arrays[i](annot); } result.world().gop.fence(); diff --git a/src/TiledArray/math/linalg/basic.h b/src/TiledArray/math/linalg/basic.h index c00a363286..856c915bbe 100644 --- a/src/TiledArray/math/linalg/basic.h +++ b/src/TiledArray/math/linalg/basic.h @@ -79,14 +79,14 @@ template inline void vec_multiply(DistArray& a1, const DistArray& a2) { auto vars = TiledArray::detail::dummy_annotation(rank(a1)); - a1.make_tsrexpr(vars) = a1.make_tsrexpr(vars) * a2.make_tsrexpr(vars); + a1(vars) = a1(vars) * a2(vars); } template inline void scale(DistArray& a, S scaling_factor) { using numeric_type = typename DistArray::numeric_type; auto vars = TiledArray::detail::dummy_annotation(rank(a)); - a.make_tsrexpr(vars) = numeric_type(scaling_factor) * a.make_tsrexpr(vars); + a(vars) = numeric_type(scaling_factor) * a(vars); } template @@ -99,8 +99,7 @@ inline void axpy(DistArray& y, S alpha, const DistArray& x) { using numeric_type = typename DistArray::numeric_type; auto vars = TiledArray::detail::dummy_annotation(rank(y)); - y.make_tsrexpr(vars) = - y.make_tsrexpr(vars) + numeric_type(alpha) * x.make_tsrexpr(vars); + y(vars) = y(vars) + numeric_type(alpha) * x(vars); } /// selector for concat From 8ae0cdaa81fa4f837ca9ffd4f90437b218a5a504 Mon Sep 17 00:00:00 2001 From: Bimal Gaudel Date: Mon, 8 Jul 2024 10:59:51 -0400 Subject: [PATCH 2/3] Bug fix the logic in einsum function that delegates evaluation to the expression layer. --- src/TiledArray/einsum/tiledarray.h | 19 ++++++++++++++----- tests/einsum.cpp | 6 ++++++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/src/TiledArray/einsum/tiledarray.h b/src/TiledArray/einsum/tiledarray.h index 4777a8656e..72dea21231 100644 --- a/src/TiledArray/einsum/tiledarray.h +++ b/src/TiledArray/einsum/tiledarray.h @@ -526,12 +526,21 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, auto i = (a & b) - h; // - // - no Hadamard indices for non-nested DistArray imply evaluation can be - // delegated to expression layer. - // - only Hadamard indices for nested and non-nested DistArray imply - // evaluation can be delegated to expression layer. + // *) Pure Hadamard indices: (h && !(i || e)) is true implies + // the evaluation can be delegated to the expression layer + // for distarrays of both nested and non-nested tensor tiles. + // *) If no Hadamard indices are present (!h) the evaluation + // can be delegated to the expression _only_ for distarrays with + // non-nested tensor tiles. + // This is because even if Hadamard indices are not present, a contracted + // index might be present pertinent to the outer tensor in case of a + // nested-tile distarray, which is especially handled within this + // function because expression layer cannot handle that yet. // - if ((!IsArrayToT && !h) || (h && !(i || e))) { + if ((h && !(i || e)) // pure Hadamard + || (IsArrayToT && !(i || h)) // ToT result from outer-product + || (IsArrayT && !h) // T from general product without Hadamard + ) { ArrayC C; C(std::string(c) + inner.c) = A * B; return C; diff --git a/tests/einsum.cpp b/tests/einsum.cpp index 646d636b65..6be4a4a99d 100644 --- a/tests/einsum.cpp +++ b/tests/einsum.cpp @@ -433,6 +433,12 @@ BOOST_AUTO_TEST_CASE(corner_cases) { {{0, 4, 8}, {0, 4}}, // {8}))); + BOOST_REQUIRE(check_manual_eval("il;bae,il;e->li;ab", // + {{0, 2}, {0, 4}}, // + {{0, 2}, {0, 4}}, // + {4, 2, 3}, // + {3})); + BOOST_REQUIRE( check_manual_eval("ijkl;abecdf,k;e->ijl;bafdc", // {{0, 2}, {0, 3}, {0, 4}, {0, 5}}, // From 6994163c27395b5247f92580953e0b2ca3697e29 Mon Sep 17 00:00:00 2001 From: Bimal Gaudel Date: Sun, 7 Jul 2024 15:16:06 -0400 Subject: [PATCH 3/3] ToT support for `math/linalg` functions and `concat` function. --- src/TiledArray/conversions/concat.h | 5 +++-- src/TiledArray/math/linalg/basic.h | 7 ++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/TiledArray/conversions/concat.h b/src/TiledArray/conversions/concat.h index 7c440c54e2..398a5dc7b3 100644 --- a/src/TiledArray/conversions/concat.h +++ b/src/TiledArray/conversions/concat.h @@ -92,8 +92,9 @@ DistArray concat( DistArray result(*target_world, tr); const auto annot = detail::dummy_annotation(r); for (auto i = 0ul; i != arrays.size(); ++i) { - result(annot).block(tile_begin_end[i].first, tile_begin_end[i].second) = - arrays[i](annot); + result.make_tsrexpr(annot).block(tile_begin_end[i].first, + tile_begin_end[i].second) = + arrays[i].make_tsrexpr(annot); } result.world().gop.fence(); diff --git a/src/TiledArray/math/linalg/basic.h b/src/TiledArray/math/linalg/basic.h index 856c915bbe..c00a363286 100644 --- a/src/TiledArray/math/linalg/basic.h +++ b/src/TiledArray/math/linalg/basic.h @@ -79,14 +79,14 @@ template inline void vec_multiply(DistArray& a1, const DistArray& a2) { auto vars = TiledArray::detail::dummy_annotation(rank(a1)); - a1(vars) = a1(vars) * a2(vars); + a1.make_tsrexpr(vars) = a1.make_tsrexpr(vars) * a2.make_tsrexpr(vars); } template inline void scale(DistArray& a, S scaling_factor) { using numeric_type = typename DistArray::numeric_type; auto vars = TiledArray::detail::dummy_annotation(rank(a)); - a(vars) = numeric_type(scaling_factor) * a(vars); + a.make_tsrexpr(vars) = numeric_type(scaling_factor) * a.make_tsrexpr(vars); } template @@ -99,7 +99,8 @@ inline void axpy(DistArray& y, S alpha, const DistArray& x) { using numeric_type = typename DistArray::numeric_type; auto vars = TiledArray::detail::dummy_annotation(rank(y)); - y(vars) = y(vars) + numeric_type(alpha) * x(vars); + y.make_tsrexpr(vars) = + y.make_tsrexpr(vars) + numeric_type(alpha) * x.make_tsrexpr(vars); } /// selector for concat