diff --git a/src/TiledArray/dist_array.h b/src/TiledArray/dist_array.h index 03be325d05..259c9f9414 100644 --- a/src/TiledArray/dist_array.h +++ b/src/TiledArray/dist_array.h @@ -33,8 +33,8 @@ #include "TiledArray/util/initializer_list.h" #include "TiledArray/util/random.h" -#include #include +#include namespace TiledArray { @@ -49,8 +49,7 @@ class Tensor; /// used to construct distributed tensor algebraic operations. /// \tparam T The element type of for array tiles /// \tparam Tile The tile type [ Default = \c Tensor ] -template , - typename Policy = DensePolicy> +template , typename Policy = DensePolicy> class DistArray : public madness::archive::ParallelSerializableObject { public: typedef TiledArray::detail::ArrayImpl @@ -299,27 +298,41 @@ class DistArray : public madness::archive::ParallelSerializableObject { /// `{{1, 2}, {3, 4, 5}}`). If an exception is /// raised \p world and \p il are unchanged. template - DistArray(World& world, detail::vector_il il) + DistArray(World& world, + std::initializer_list + il) // N.B. clang does not like detail::vector_il here : DistArray(array_from_il(world, il)) {} template - DistArray(World& world, detail::matrix_il il) + DistArray(World& world, std::initializer_list> il) : DistArray(array_from_il(world, il)) {} template - DistArray(World& world, detail::tensor3_il il) + DistArray( + World& world, + std::initializer_list>> il) : DistArray(array_from_il(world, il)) {} template - DistArray(World& world, detail::tensor4_il il) + DistArray(World& world, std::initializer_list>>> + il) : DistArray(array_from_il(world, il)) {} template - DistArray(World& world, detail::tensor5_il il) + DistArray(World& world, + std::initializer_list>>>> + il) : DistArray(array_from_il(world, il)) {} template - DistArray(World& world, detail::tensor6_il il) + DistArray( + World& world, + std::initializer_list< + std::initializer_list>>>>> + il) : DistArray(array_from_il(world, il)) {} ///@} @@ -350,27 +363,42 @@ class DistArray : public madness::archive::ParallelSerializableObject { /// `{{1, 2}, {3, 4, 5}}`). If an exception is /// raised \p world and \p il are unchanged. template - DistArray(World& world, const trange_type& trange, detail::vector_il il) + DistArray(World& world, const trange_type& trange, + std::initializer_list il) : DistArray(array_from_il(world, trange, il)) {} template - DistArray(World& world, const trange_type& trange, detail::matrix_il il) + DistArray(World& world, const trange_type& trange, + std::initializer_list> il) : DistArray(array_from_il(world, trange, il)) {} template - DistArray(World& world, const trange_type& trange, detail::tensor3_il il) + DistArray( + World& world, const trange_type& trange, + std::initializer_list>> il) : DistArray(array_from_il(world, trange, il)) {} template - DistArray(World& world, const trange_type& trange, detail::tensor4_il il) + DistArray(World& world, const trange_type& trange, + std::initializer_list>>> + il) : DistArray(array_from_il(world, trange, il)) {} template - DistArray(World& world, const trange_type& trange, detail::tensor5_il il) + DistArray(World& world, const trange_type& trange, + std::initializer_list>>>> + il) : DistArray(array_from_il(world, trange, il)) {} template - DistArray(World& world, const trange_type& trange, detail::tensor6_il il) + DistArray( + World& world, const trange_type& trange, + std::initializer_list< + std::initializer_list>>>>> + il) : DistArray(array_from_il(world, trange, il)) {} /// @} @@ -978,8 +1006,7 @@ class DistArray : public madness::archive::ParallelSerializableObject { /// DistArray::check_str_index() auto operator()(const std::string& vars) const { check_str_index(vars); - return TiledArray::expressions::TsrExpr(*this, - vars); + return TiledArray::expressions::TsrExpr(*this, vars); } /// Create a tensor expression @@ -1215,9 +1242,8 @@ class DistArray : public madness::archive::ParallelSerializableObject { /// @tparam Archive an Archive type /// @warning this does not fence; it is user's responsibility to do that template >> + typename = std::enable_if_t>> void serialize(const Archive& ar) const { // serialize array type, world size, rank, and pmap type to be able // to ensure same data type and same data distribution expected @@ -1235,9 +1261,8 @@ class DistArray : public madness::archive::ParallelSerializableObject { /// @tparam Archive an Archive type /// @warning this does not fence; it is user's responsibility to do that template >> + typename = std::enable_if_t>> void serialize(const Archive& ar) { auto& world = TiledArray::get_default_world(); diff --git a/tests/initializer_list.cpp b/tests/initializer_list.cpp index 0fde6f207d..884f5c61fd 100644 --- a/tests/initializer_list.cpp +++ b/tests/initializer_list.cpp @@ -539,6 +539,12 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(vector, T, scalar_type_list) { for (tile_type tile : array) { BOOST_CHECK(std::equal(tile.begin(), tile.end(), il.begin())); } + + // test that can construct (sparse array) directly from an initializer list + auto sp_array = TSpArray(world, {T{1}, T{2}, T{3}}); + for (auto&& tile : sp_array) { + BOOST_CHECK(std::equal(tile.get().begin(), tile.get().end(), il.begin())); + } } BOOST_AUTO_TEST_CASE_TEMPLATE(empty_matrix, T, scalar_type_list) { @@ -566,6 +572,10 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(square_matrix, T, scalar_type_list) { } } } + + // test that can construct (sparse array) directly from an initializer list + auto sp_array = TSpArray( + world, {{T{1}, T{2}, T{3}}, {T{4}, T{5}, T{6}}, {T{7}, T{8}, T{9}}}); } BOOST_AUTO_TEST_CASE_TEMPLATE(tall_matrix, T, scalar_type_list) {