Skip to content

Commit

Permalink
Implement diagonal function and extended einsum
Browse files Browse the repository at this point in the history
  • Loading branch information
mfherbst committed Jun 17, 2020
1 parent eac300e commit 9e7957d
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 14 deletions.
2 changes: 1 addition & 1 deletion adcc/adc_pp/state2state_transition_dm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

import libadcc

DISPATCH = {}
DISPATCH = {} # None implemented


def state2state_transition_dm(method, ground_state, amplitude_from,
Expand Down
8 changes: 4 additions & 4 deletions adcc/adc_pp/state_diffdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ def diffdm_adc2(mp, amplitude, intermediates):
p1_oo = dm[b.oo].evaluate() # ADC(1) diffdm
p1_vv = dm[b.vv].evaluate() # ADC(1) diffdm

# Compute zeroth order doubles contributions
# Zeroth order doubles contributions
p2_oo = -einsum("ikab,jkab->ij", u2, u2)
p2_vv = einsum("ijac,ijbc->ab", u2, u2)
p2_ov = -2 * einsum("jb,ijab->ia", u1, u2)
p2_ov = -2 * einsum("jb,ijab->ia", u1, u2).evaluate()

# ADC(2) ISR intermediate (TODO Move to intermediates)
ru1 = einsum("ijab,jb->ia", t2, u1)
ru1 = einsum("ijab,jb->ia", t2, u1).evaluate()

# Compute second-order contributions to the density matrix
dm[b.oo] = ( # adc2_p_oo
Expand Down Expand Up @@ -127,7 +127,7 @@ def diffdm_cvs_adc2(mp, amplitude, intermediates):
p0_vv = intermediates.cv_p_vv
p1_vv = dm[b.vv].evaluate() # ADC(1) diffdm

# Compute zeroth order doubles contributions
# Zeroth order doubles contributions
p2_ov = -sqrt(2) * einsum("jb,ijab->ia", u1, u2)
p2_vo = -sqrt(2) * einsum("ijab,jb->ai", u2, u1)
p2_oo = -einsum("ljab,kjab->kl", u2, u2)
Expand Down
93 changes: 92 additions & 1 deletion adcc/opt_einsum_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,99 @@
__all__ = ["register_with_opt_einsum"]


def _dispatch_diagonal(subscript, outstring, operand):
# Do the diagonal call one index at a time.
char = None
for c in subscript:
if subscript.count(c) > 1:
char = c
break
assert char is not None
indices = [i for (i, c) in enumerate(subscript) if c == char]

# Diagonal pushes the diagonalised index back
newoperand = operand.diagonal(*indices)
newsubscript = "".join(c for c in subscript if c != char) + char
if len(newsubscript) > len(outstring):
# Recurse if more letters for which diagonals are to be extracted
return _dispatch_diagonal(newsubscript, outstring, newoperand)
elif newsubscript == outstring:
return newoperand
else:
# Transpose newsubscript -> outstring
permutation = tuple(map(newsubscript.index, outstring))
return newoperand.transpose(permutation)


def _fallback_einsum(einsum_str, *operands, **kwargs):
# A fallback implementation of einsum in adcc,
# which deals with a few cases opt_einsum cannot deal with
from .functions import einsum
from opt_einsum.parser import gen_unused_symbols

operands = list(operands)
subscripts = einsum_str.split("->")[0].split(",")
outstr = einsum_str.split("->")[1]

# If there are any diagonal extractions, which can be done in the operands,
# do them first.
for i in range(len(subscripts)):
sub = subscripts[i]
cdiagonal = set(c for c in sub if sub.count(c) > 1 and c in outstr)
ctrace = set(c for c in sub if sub.count(c) > 1 and c not in outstr)
if ctrace:
raise NotImplementedError("Partial traces (e.g. contractions "
"'iaib->ab') are not yet supported "
"in adcc.einsum.")
if cdiagonal:
# Do any possible diagonal extraction first
outstring = "".join(c for c in sub if c not in cdiagonal)
outstring += "".join(cdiagonal)
operands[i] = _dispatch_diagonal(subscripts[i], outstring,
operands[i])
subscripts[i] = outstring

if len(subscripts) == 1:
# At this point all which is left should be a permutation.
assert all(c in outstr for c in subscripts[0])
permutation = tuple(map(subscripts[0].index, outstr))
return operands[0].transpose(permutation)
elif len(subscripts) == 2:
# Should the diagonal of a contraction be extracted, e.g. il,laib->aib
diagonal_chars = [a for a in subscripts[0]
if a in subscripts[1] and a in outstr]

if not diagonal_chars:
# Try another round of einsum
return einsum(",".join(subscripts) + "->" + outstr, *operands)

# Replace one of the duplicate characters in the input
# and prepend it to output
replacers = list(gen_unused_symbols(outstr + "".join(subscripts),
len(diagonal_chars)))
newoutstr = "".join(replacers) + outstr
for (old, new) in zip(diagonal_chars, replacers):
subscripts[0] = subscripts[0].replace(old, new)

# Check we are not creating an infinite loop:
assert ",".join(subscripts) + "->" + newoutstr != einsum_str

# Run einsum doing the partial contraction
# (well actually we should directly do tensordot)
res = einsum(",".join(subscripts) + "->" + newoutstr, *operands)

# Run _dispatch_diagonal with the result to form the requested diagonal
return _dispatch_diagonal("".join(diagonal_chars) + outstr, outstr, res)
else:
raise NotImplementedError("Fallback einsum not implemented for more than "
"two operators")


def register_with_opt_einsum():
import libadcc

from opt_einsum.backends.dispatch import EVAL_CONSTS_BACKENDS
from opt_einsum.backends.dispatch import (EVAL_CONSTS_BACKENDS,
_cached_funcs, _has_einsum)

def libadcc_evaluate_constants(const_arrays, expr):
# Compute the partial expression tree of the inputs
Expand All @@ -39,3 +128,5 @@ def libadcc_evaluate_constants(const_arrays, expr):
new_ops = [None if x is None else libadcc.evaluate(x) for x in new_ops]
return new_ops, new_contraction_list
EVAL_CONSTS_BACKENDS["libadcc"] = libadcc_evaluate_constants
_has_einsum["libadcc"] = False
_cached_funcs[('einsum', 'libadcc')] = _fallback_einsum
17 changes: 15 additions & 2 deletions adcc/test_Tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@
import unittest
import numpy as np

from .misc import expand_test_templates
from numpy.testing import assert_allclose

from adcc import direct_sum, einsum, empty_like, nosym_like
from adcc.testdata.cache import cache

from .misc import expand_test_templates


@expand_test_templates(["h2o_sto3g", "cn_sto3g"])
class TestTensor(unittest.TestCase):
Expand Down Expand Up @@ -130,3 +129,17 @@ def test_nontrivial_direct_sum(self):
res = direct_sum("ab-i-c->iacb", einsum("ijac,ijcb->ab", oovv, oovv),
oeo, oev)
assert_allclose(res.to_ndarray(), ref, rtol=1e-10, atol=1e-14)

def test_nontrivial_diagonal(self):
refstate = cache.refstate["cn_sto3g"]
mtcs = [nosym_like(refstate.eri("o1o1v1v1")).set_random(),
nosym_like(refstate.eri("o1v1o1v1")).set_random(),
nosym_like(refstate.eri("o1o1v1v1")).set_random()]
mnps = [m.to_ndarray() for m in mtcs]

res = einsum("ijab,ibkc,kjad->cd", mtcs[0], mtcs[1], mtcs[2]).diagonal()
ref = np.einsum("ijab,ibkc,kjad->cd", mnps[0], mnps[1],
mnps[2]).diagonal()

assert res.needs_evaluation
assert_allclose(res.to_ndarray(), ref, rtol=1e-10, atol=1e-14)
63 changes: 58 additions & 5 deletions adcc/test_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,14 @@
from adcc import einsum, empty_like, nosym_like
from adcc.testdata.cache import cache

import pytest


class TestEinsum(unittest.TestCase):
def base_test(self, contr, a, b):
a.set_random()
b.set_random()
ref = np.einsum(contr, a.to_ndarray(), b.to_ndarray())
out = einsum(contr, a, b)
def base_test(self, contr, *arr):
arr = [a.set_random() for a in arr]
ref = np.einsum(contr, *[a.to_ndarray() for a in arr])
out = einsum(contr, *arr)
assert_allclose(out.to_ndarray(), ref, rtol=1e-10, atol=1e-14)

def test_einsum_1_1_0(self):
Expand Down Expand Up @@ -127,3 +128,55 @@ def test_einsum_4_4_2(self): # (3, 4, 4) in C++
a = empty_like(refstate.eri("o1v1o1v1"))
b = empty_like(refstate.eri("o1o1o1v1"))
self.base_test("iajb,jikb->ka", a, b)

def test_diagonal_1(self):
refstate = cache.refstate["cn_sto3g"]
a = empty_like(refstate.eri("o1v1o1v1"))
self.base_test("iaia->ia", a)

def test_diagonal_2(self):
refstate = cache.refstate["cn_sto3g"]
a = empty_like(refstate.eri("o1v1o1v1"))
self.base_test("iaia->ai", a)

def test_diagonal_3(self):
refstate = cache.refstate["h2o_sto3g"]
a = empty_like(refstate.ovov)
self.base_test("iaja->aij", a)

def test_diagonal_4(self):
refstate = cache.refstate["h2o_sto3g"]
a = empty_like(refstate.ovov)
self.base_test("iaja->ija", a)

def test_diagonal_5(self):
refstate = cache.refstate["cn_sto3g"]
a = empty_like(refstate.ovov)
b = empty_like(refstate.foo)
self.base_test("laib,il->aib", a, b)

def test_diagonal_6(self):
refstate = cache.refstate["h2o_sto3g"]
a = empty_like(refstate.fov)
b = empty_like(refstate.fvo)
self.base_test("ai,ia->i", a, b)

def test_diagonal_7(self):
refstate = cache.refstate["h2o_sto3g"]
a = empty_like(refstate.ovov)
b = empty_like(refstate.fvv)
self.base_test("iaib,ba->ia", a, b)

def test_partial_trace(self):
refstate = cache.refstate["h2o_sto3g"]
a = empty_like(refstate.ovov)
b = empty_like(refstate.fvv)
with pytest.raises(NotImplementedError, match=r"Partial traces.*"):
einsum("iaib,ba->a", a, b)

def test_thc(self):
refstate = cache.refstate["h2o_sto3g"]
a = nosym_like(refstate.foo)
b = nosym_like(refstate.fov)
c = nosym_like(refstate.foo)
self.base_test("ij,ia,ik->jka", a, b, c)
12 changes: 12 additions & 0 deletions extension/export_Tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,17 @@ static py::object tensordot_2(ten_ptr a, ten_ptr b, size_t axes) {

static py::object tensordot_3(ten_ptr a, ten_ptr b) { return tensordot_2(a, b, 2); }

static ten_ptr Tensor_diagonal(ten_ptr ten, py::args permutations) {
std::vector<size_t> axes;
if (py::len(permutations) == 0) {
axes.push_back(0);
axes.push_back(1);
} else {
for (auto itm : permutations) axes.push_back(itm.cast<size_t>());
}
return ten->diagonal(axes);
}

static ten_ptr direct_sum(ten_ptr a, ten_ptr b) { return a->direct_sum(b); }

static double Tensor_trace_1(std::string subscripts, const Tensor& tensor) {
Expand Down Expand Up @@ -446,6 +457,7 @@ void export_Tensor(py::module& m) {
.def("set_mask", &adcc::Tensor::set_mask,
"Set all elements corresponding to an index mask, which is given by a "
"string eg. 'iijkli' sets elements T_{iijkli}")
.def("diagonal", &Tensor_diagonal)
.def("copy", &Tensor::copy, "Returns a deep copy of the tensor.")
.def("dot", &Tensor_dot)
.def("dot", &Tensor_dot_list)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def finalize_options(self):

# Version of the python bindings and adcc python package.
__version__ = "0.15.0"
adccore_version = ("0.14.2", "") # (base version, unstable postfix)
adccore_version = ("0.14.3", "") # (base version, unstable postfix)


def is_conda_build():
Expand Down

0 comments on commit 9e7957d

Please sign in to comment.