src/TiledArray/dist_array.h: 33 additions & 0 deletions
@@ -35,6 +35,7 @@

#include <madness/world/parallel_archive.h>
#include <cstdlib>
#include <tuple>

namespace TiledArray {

@@ -1708,6 +1709,38 @@ auto norm2(const DistArray<Tile, Policy>& a) {
return std::sqrt(squared_norm(a));
}

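/// Constructs an Array from a half-open range of (index, tile) pairs.
/// For a dense policy the array is created directly from \p tiled_range;
/// for a sparse policy a shape is first assembled from the tile norms,
/// and tiles screened out by that shape are skipped.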
template<typename Array, typename Tiles>
Array make_array(
World &world,
const detail::trange_t<Array> &tiled_range,
Tiles begin, Tiles end)
{
Array array;
using Tuple = std::remove_reference_t<decltype(*begin)>;
using Index = std::tuple_element_t<0,Tuple>;
using shape_type = typename Array::shape_type;
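// Dense policies need no shape; sparse policies need per-tile norms to build one.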
if constexpr (shape_type::is_dense()) {
array = Array(world, tiled_range);
}
else {
std::vector< std::pair<Index,float> > tile_norms;
for (Tiles it = begin; it != end; ++it) {
auto [index,tile] = *it;
tile_norms.push_back({index,tile.norm()});
}
shape_type shape(world, tile_norms, tiled_range);
array = Array(world, tiled_range, shape);
}
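// Insert the supplied tiles, skipping any the shape marks as zero.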
for (Tiles it = begin; it != end; ++it) {
auto [index,tile] = *it;
if (array.is_zero(index)) continue;
array.set(index,tile);
}
return array;
}
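// Example use (illustrative names, not part of this change): given a
// std::vector<std::pair<Index,Tensor>> `tiles` holding this rank's tiles,
//   auto A = TiledArray::make_array<TiledArray::TSpArrayD>(
//       world, trange, tiles.begin(), tiles.end());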



} // namespace TiledArray

// serialization
src/TiledArray/einsum/tiledarray.h: 33 additions & 6 deletions
@@ -2,6 +2,7 @@
#define TILEDARRAY_EINSUM_H__INCLUDED

#include "TiledArray/fwd.h"
#include "TiledArray/dist_array.h"
#include "TiledArray/expressions/fwd.h"
#include "TiledArray/einsum/index.h"
#include "TiledArray/einsum/range.h"
@@ -53,6 +54,8 @@ auto einsum(
{

using Array = std::remove_cv_t<Array_>;
using Tensor = typename Array::value_type;
using Shape = typename Array::shape_type;

auto a = std::get<0>(Einsum::idx(A));
auto b = std::get<0>(Einsum::idx(B));
@@ -103,6 +106,7 @@
TiledRange ei_tiled_range;
Array ei;
std::string expr;
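// This rank's permuted and reshaped tiles of ei; ei is assembled from them via make_array.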
std::vector< std::pair<Einsum::Index<size_t>,Tensor> > local_tiles;
bool own(Einsum::Index<size_t> h) const {
for (Einsum::Index<size_t> ei : tiles) {
auto idx = apply_inverse(permutation, h+ei);
@@ -149,7 +153,6 @@
}

using Index = Einsum::Index<size_t>;
using Tensor = typename Array::value_type;

if constexpr(std::tuple_size<decltype(cs)>::value > 1) {
TA_ASSERT(e);
@@ -169,7 +172,7 @@
for (size_t i = 0; i < h.size(); ++i) {
batch *= H.batch[i].at(h[i]);
}
Tensor tile(TiledArray::Range{batch});
Tensor tile(TiledArray::Range{batch}, typename Tensor::value_type());
for (Index i : tiles) {
// skip this unless both input tiles exist
const auto pahi_inv = apply_inverse(pa,h+i);
@@ -208,6 +211,7 @@
}

std::vector< std::shared_ptr<World> > worlds;
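// Result tiles of C computed on this rank; they are written into C.array after the
// loop (for sparse policies, only once its shape has been built from their norms).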
std::vector< std::tuple<Index,Tensor> > local_tiles;

// iterates over tiles of hadamard indices
for (Index h : H.tiles) {
@@ -222,21 +226,29 @@
batch *= H.batch[i].at(h[i]);
}
for (auto &term : AB) {
term.ei = Array(*owners, term.ei_tiled_range);
term.local_tiles.clear();
const Permutation &P = term.permutation;
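// For each local, non-zero tile of this term: fetch it, permute if needed,
// reshape to the einsum tile shape, and stash it in local_tiles.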
for (Index ei : term.tiles) {
auto idx = apply_inverse(P, h+ei);
if (!term.array.is_local(idx)) continue;
if (term.array.is_zero(idx)) continue;
auto tile = term.array.find(idx).get();
if (P) tile = tile.permute(P);
auto shape = term.ei.trange().tile(ei);
auto shape = term.ei_tiled_range.tile(ei);
tile = tile.reshape(shape, batch);
term.ei.set(ei, tile);
term.local_tiles.push_back({ei, tile});
}
term.ei = TiledArray::make_array<Array>(
*owners,
term.ei_tiled_range,
term.local_tiles.begin(),
term.local_tiles.end()
);
}
C.ei(C.expr) = (A.ei(A.expr) * B.ei(B.expr)).set_world(*owners);
for (Index e : C.tiles) {
if (!C.ei.is_local(e)) continue;
if (C.ei.is_zero(e)) continue;
auto tile = C.ei.find(e).get();
assert(tile.batch_size() == batch);
const Permutation &P = C.permutation;
@@ -245,14 +257,29 @@
shape = apply_inverse(P, shape);
tile = tile.reshape(shape);
if (P) tile = tile.permute(P);
C.array.set(c, tile);
local_tiles.push_back({c, tile});
}
// mark for lazy deletion
A.ei = Array();
B.ei = Array();
C.ei = Array();
}

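// For sparse policies, build C's shape from the norms of the locally computed tiles
// before constructing C.array and setting them.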
if constexpr (!Shape::is_dense()) {
TiledRange tiled_range = TiledRange(range_map[c]);
std::vector< std::pair<Index,float> > tile_norms;
for (auto& [index,tile] : local_tiles) {
tile_norms.push_back({index,tile.norm()});
}
Shape shape(world, tile_norms, tiled_range);
C.array = Array(world, TiledRange(range_map[c]), shape);
}

for (auto& [index,tile] : local_tiles) {
if (C.array.is_zero(index)) continue;
C.array.set(index, tile);
}

for (auto &w : worlds) {
w->gop.fence();
}
tests/einsum.cpp: 74 additions & 22 deletions
@@ -494,15 +494,28 @@ BOOST_AUTO_TEST_SUITE_END()
// TiledArray einsum expressions
BOOST_AUTO_TEST_SUITE(einsum_tiledarray)

template<typename T = Tensor<int>, typename ... Args>
using TiledArray::SparsePolicy;
using TiledArray::DensePolicy;

template<typename Policy, typename T = Tensor<int>, typename ... Args>
auto random(Args ... args) {
TiledArray::TiledRange tr{ {0, args}... };
auto& world = TiledArray::get_default_world();
TiledArray::DistArray<T,TiledArray::SparsePolicy> t(world,tr);
TiledArray::DistArray<T,Policy> t(world,tr);
t.fill_random();
return t;
}

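// Makes a sparse DistArray whose shape has every tile norm equal to zero,
// so all tiles are screened out.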
template<typename T = Tensor<int>, typename ... Args>
auto sparse_zero(Args ... args) {
TiledArray::TiledRange tr{ {0, args}... };
auto& world = TiledArray::get_default_world();
TiledArray::SparsePolicy::shape_type shape(0.0f, tr);
TiledArray::DistArray<T,TiledArray::SparsePolicy> t(world,tr,shape);
t.fill(0);
return t;
}

template<int NA, int NB, int NC, typename T, typename Policy>
void einsum_tiledarray_check(
TiledArray::DistArray<T,Policy> &&A,
@@ -523,85 +536,124 @@ void einsum_tiledarray_check(
array_to_eigen_tensor<Tensor<U,NB>>(B)
);
auto result = array_to_eigen_tensor<TC>(C);
//std::cout << "e=" << result << std::endl;
BOOST_CHECK(isApprox(result, reference));
}

BOOST_AUTO_TEST_CASE(einsum_tiledarray_ak_bk_ab) {
einsum_tiledarray_check<2,2,2>(
random(11,7),
random(13,7),
random<SparsePolicy>(11,7),
random<SparsePolicy>(13,7),
"ak,bk->ab"
);
}

BOOST_AUTO_TEST_CASE(einsum_tiledarray_ka_bk_ba) {
einsum_tiledarray_check<2,2,2>(
random(7,11),
random(13,7),
random<SparsePolicy>(7,11),
random<SparsePolicy>(13,7),
"ka,bk->ba"
);
}

BOOST_AUTO_TEST_CASE(einsum_tiledarray_abi_cdi_cdab) {
einsum_tiledarray_check<3,3,4>(
random(21,22,3),
random(24,25,3),
random<SparsePolicy>(21,22,3),
random<SparsePolicy>(24,25,3),
"abi,cdi->cdab"
);
}

BOOST_AUTO_TEST_CASE(einsum_tiledarray_icd_ai_abcd) {
einsum_tiledarray_check<3,3,4>(
random(3,12,13),
random(14,15,3),
random<SparsePolicy>(3,12,13),
random<SparsePolicy>(14,15,3),
"icd,bai->abcd"
);
}

BOOST_AUTO_TEST_CASE(einsum_tiledarray_cdji_ibja_abcd) {
einsum_tiledarray_check<4,4,4>(
random(14,15,3,5),
random(5,12,3,13),
random<SparsePolicy>(14,15,3,5),
random<SparsePolicy>(5,12,3,13),
"cdji,ibja->abcd"
);
}

BOOST_AUTO_TEST_CASE(einsum_tiledarray_hai_hbi_hab) {
einsum_tiledarray_check<3,3,3>(
random(7,14,3),
random(7,15,3),
random<SparsePolicy>(7,14,3),
random<SparsePolicy>(7,15,3),
"hai,hbi->hab"
);
einsum_tiledarray_check<3,3,3>(
sparse_zero(7,14,3),
sparse_zero(7,15,3),
"hai,hbi->hab"
);
}

BOOST_AUTO_TEST_CASE(einsum_tiledarray_iah_hib_bha) {
einsum_tiledarray_check<3,3,3>(
random(7,14,3),
random(3,7,15),
random<SparsePolicy>(7,14,3),
random<SparsePolicy>(3,7,15),
"iah,hib->bha"
);
einsum_tiledarray_check<3,3,3>(
sparse_zero(7,14,3),
sparse_zero(3,7,15),
"iah,hib->bha"
);
}

BOOST_AUTO_TEST_CASE(einsum_tiledarray_iah_hib_abh) {
einsum_tiledarray_check<3,3,3>(
random(7,14,3),
random(3,7,15),
random<SparsePolicy>(7,14,3),
random<SparsePolicy>(3,7,15),
"iah,hib->abh"
);
einsum_tiledarray_check<3,3,3>(
sparse_zero(7,14,3),
sparse_zero(3,7,15),
"iah,hib->abh"
);
}

BOOST_AUTO_TEST_CASE(einsum_tiledarray_hai_hibc_habc) {
einsum_tiledarray_check<3,4,4>(
random<SparsePolicy>(9,3,11),
random<SparsePolicy>(9,11,5,7),
"hai,hibc->habc"
);
einsum_tiledarray_check<3,4,4>(
sparse_zero(9,3,11),
sparse_zero(9,11,5,7),
"hai,hibc->habc"
);
}

BOOST_AUTO_TEST_CASE(einsum_tiledarray_hi_hi_h) {
einsum_tiledarray_check<2,2,1>(
random(7,14),
random(7,14),
random<SparsePolicy>(7,14),
random<SparsePolicy>(7,14),
"hi,hi->h"
);
einsum_tiledarray_check<2,2,1>(
sparse_zero(7,14),
sparse_zero(7,14),
"hi,hi->h"
);
}

BOOST_AUTO_TEST_CASE(einsum_tiledarray_hji_jih_hj) {
einsum_tiledarray_check<3,3,2>(
random(14,7,5),
random(7,5,14),
random<SparsePolicy>(14,7,5),
random<SparsePolicy>(7,5,14),
"hji,jih->hj"
);
einsum_tiledarray_check<3,3,2>(
sparse_zero(14,7,5),
sparse_zero(7,5,14),
"hji,jih->hj"
);
}