Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 48 additions & 23 deletions src/TiledArray/dist_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
#include "TiledArray/util/initializer_list.h"
#include "TiledArray/util/random.h"

#include <cstdlib>
#include <madness/world/parallel_archive.h>
#include <cstdlib>

namespace TiledArray {

Expand All @@ -49,8 +49,7 @@ class Tensor;
/// used to construct distributed tensor algebraic operations.
/// \tparam T The element type of for array tiles
/// \tparam Tile The tile type [ Default = \c Tensor<T> ]
template <typename Tile = Tensor<double>,
typename Policy = DensePolicy>
template <typename Tile = Tensor<double>, typename Policy = DensePolicy>
class DistArray : public madness::archive::ParallelSerializableObject {
public:
typedef TiledArray::detail::ArrayImpl<Tile, Policy>
Expand Down Expand Up @@ -299,27 +298,41 @@ class DistArray : public madness::archive::ParallelSerializableObject {
/// `{{1, 2}, {3, 4, 5}}`). If an exception is
/// raised \p world and \p il are unchanged.
template <typename T>
DistArray(World& world, detail::vector_il<T> il)
DistArray(World& world,
std::initializer_list<T>
il) // N.B. clang does not like detail::vector_il<T> here
: DistArray(array_from_il<DistArray>(world, il)) {}

template <typename T>
DistArray(World& world, detail::matrix_il<T> il)
DistArray(World& world, std::initializer_list<std::initializer_list<T>> il)
: DistArray(array_from_il<DistArray>(world, il)) {}

template <typename T>
DistArray(World& world, detail::tensor3_il<T> il)
DistArray(
World& world,
std::initializer_list<std::initializer_list<std::initializer_list<T>>> il)
: DistArray(array_from_il<DistArray>(world, il)) {}

template <typename T>
DistArray(World& world, detail::tensor4_il<T> il)
DistArray(World& world, std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<T>>>>
il)
: DistArray(array_from_il<DistArray>(world, il)) {}

template <typename T>
DistArray(World& world, detail::tensor5_il<T> il)
DistArray(World& world,
std::initializer_list<std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<T>>>>>
il)
: DistArray(array_from_il<DistArray>(world, il)) {}

template <typename T>
DistArray(World& world, detail::tensor6_il<T> il)
DistArray(
World& world,
std::initializer_list<
std::initializer_list<std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<T>>>>>>
il)
: DistArray(array_from_il<DistArray>(world, il)) {}
///@}

Expand Down Expand Up @@ -350,27 +363,42 @@ class DistArray : public madness::archive::ParallelSerializableObject {
/// `{{1, 2}, {3, 4, 5}}`). If an exception is
/// raised \p world and \p il are unchanged.
template <typename T>
DistArray(World& world, const trange_type& trange, detail::vector_il<T> il)
DistArray(World& world, const trange_type& trange,
std::initializer_list<T> il)
: DistArray(array_from_il<DistArray>(world, trange, il)) {}

template <typename T>
DistArray(World& world, const trange_type& trange, detail::matrix_il<T> il)
DistArray(World& world, const trange_type& trange,
std::initializer_list<std::initializer_list<T>> il)
: DistArray(array_from_il<DistArray>(world, trange, il)) {}

template <typename T>
DistArray(World& world, const trange_type& trange, detail::tensor3_il<T> il)
DistArray(
World& world, const trange_type& trange,
std::initializer_list<std::initializer_list<std::initializer_list<T>>> il)
: DistArray(array_from_il<DistArray>(world, trange, il)) {}

template <typename T>
DistArray(World& world, const trange_type& trange, detail::tensor4_il<T> il)
DistArray(World& world, const trange_type& trange,
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<T>>>>
il)
: DistArray(array_from_il<DistArray>(world, trange, il)) {}

template <typename T>
DistArray(World& world, const trange_type& trange, detail::tensor5_il<T> il)
DistArray(World& world, const trange_type& trange,
std::initializer_list<std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<T>>>>>
il)
: DistArray(array_from_il<DistArray>(world, trange, il)) {}

template <typename T>
DistArray(World& world, const trange_type& trange, detail::tensor6_il<T> il)
DistArray(
World& world, const trange_type& trange,
std::initializer_list<
std::initializer_list<std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<T>>>>>>
il)
: DistArray(array_from_il<DistArray>(world, trange, il)) {}
/// @}

Expand Down Expand Up @@ -978,8 +1006,7 @@ class DistArray : public madness::archive::ParallelSerializableObject {
/// DistArray::check_str_index()
auto operator()(const std::string& vars) const {
check_str_index(vars);
return TiledArray::expressions::TsrExpr<const DistArray>(*this,
vars);
return TiledArray::expressions::TsrExpr<const DistArray>(*this, vars);
}

/// Create a tensor expression
Expand Down Expand Up @@ -1215,9 +1242,8 @@ class DistArray : public madness::archive::ParallelSerializableObject {
/// @tparam Archive an Archive type
/// @warning this does not fence; it is user's responsibility to do that
template <typename Archive,
typename = std::enable_if_t<
!Archive::is_parallel_archive &&
madness::is_output_archive_v<Archive>>>
typename = std::enable_if_t<!Archive::is_parallel_archive &&
madness::is_output_archive_v<Archive>>>
void serialize(const Archive& ar) const {
// serialize array type, world size, rank, and pmap type to be able
// to ensure same data type and same data distribution expected
Expand All @@ -1235,9 +1261,8 @@ class DistArray : public madness::archive::ParallelSerializableObject {
/// @tparam Archive an Archive type
/// @warning this does not fence; it is user's responsibility to do that
template <typename Archive,
typename = std::enable_if_t<
!Archive::is_parallel_archive &&
madness::is_input_archive_v<Archive>>>
typename = std::enable_if_t<!Archive::is_parallel_archive &&
madness::is_input_archive_v<Archive>>>
void serialize(const Archive& ar) {
auto& world = TiledArray::get_default_world();

Expand Down
10 changes: 10 additions & 0 deletions tests/initializer_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,12 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(vector, T, scalar_type_list) {
for (tile_type tile : array) {
BOOST_CHECK(std::equal(tile.begin(), tile.end(), il.begin()));
}

// test that can construct (sparse array) directly from an initializer list
auto sp_array = TSpArray<T>(world, {T{1}, T{2}, T{3}});
for (auto&& tile : sp_array) {
BOOST_CHECK(std::equal(tile.get().begin(), tile.get().end(), il.begin()));
}
}

BOOST_AUTO_TEST_CASE_TEMPLATE(empty_matrix, T, scalar_type_list) {
Expand Down Expand Up @@ -566,6 +572,10 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(square_matrix, T, scalar_type_list) {
}
}
}

// test that can construct (sparse array) directly from an initializer list
auto sp_array = TSpArray<T>(
world, {{T{1}, T{2}, T{3}}, {T{4}, T{5}, T{6}}, {T{7}, T{8}, T{9}}});
}

BOOST_AUTO_TEST_CASE_TEMPLATE(tall_matrix, T, scalar_type_list) {
Expand Down