Skip to content

Commit

Permalink
revert adcblock, patch function
Browse files Browse the repository at this point in the history
  • Loading branch information
maxscheurer committed Apr 29, 2021
1 parent 00be33f commit 96b6f32
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 64 deletions.
23 changes: 14 additions & 9 deletions adcc/AdcMatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,15 @@ def __init__(self, method, hf_or_mp, block_orders=None, intermediates=None):
variant = None
if self.is_core_valence_separated:
variant = "cvs"
self.blocks_ph = { # TODO Rename to self.block in 0.16.0
blocks = {
block: ppmatrix.block(self.ground_state, block.split("_"),
order=order, intermediates=self.intermediates,
variant=variant)
for block, order in self.block_orders.items() if order is not None
}
self.__diagonal = sum(bl.diagonal for bl in self.blocks_ph.values()
# TODO Rename to self.block in 0.16.0
self.blocks_ph = {bl: blocks[bl].apply for bl in blocks}
self.__diagonal = sum(bl.diagonal for bl in blocks.values()
if bl.diagonal)
self.__diagonal.evaluate()
self.__init_space_data(self.__diagonal)
Expand All @@ -176,12 +178,15 @@ def __iadd__(self, other):
raise ValueError("Can only add to blocks of"
" AdcMatrix that already exist.")
for sp in other.blocks:
ob = other.blocks[sp]
self.blocks_ph[sp].add_block(ob)
diag = sum(bl.diagonal for bl in other.blocks.values()
if bl.diagonal)
orig_app = self.blocks_ph[sp]
other_app = other.blocks[sp].apply
def patched_apply(ampl, original=orig_app, other=other_app):
return sum(app(ampl) for app in (original, other))
self.blocks_ph[sp] = patched_apply
other_diagonal = sum(bl.diagonal for bl in other.blocks.values()
if bl.diagonal)
# iadd does not work with numbers
self.__diagonal = self.__diagonal + diag
self.__diagonal = self.__diagonal + other_diagonal
self.__diagonal.evaluate()
self.extra_terms.append(other)
return self
Expand Down Expand Up @@ -295,7 +300,7 @@ def block_apply(self, block, tensor):
with self.timer.record(f"apply/{block}"):
outblock, inblock = block.split("_")
ampl = AmplitudeVector(**{inblock: tensor})
ret = self.blocks_ph[block].apply(ampl)
ret = self.blocks_ph[block](ampl)
return getattr(ret, outblock)

@timed_member_call()
Expand All @@ -304,7 +309,7 @@ def matvec(self, v):
Compute the matrix-vector product of the ADC matrix
with an excitation amplitude and return the result.
"""
return sum(block.apply(v) for block in self.blocks_ph.values())
return sum(block(v) for block in self.blocks_ph.values())

def rmatvec(self, v):
# ADC matrix is symmetric
Expand Down
57 changes: 10 additions & 47 deletions adcc/adc_pp/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
##
## ---------------------------------------------------------------------
from math import sqrt
from collections import namedtuple

from adcc import block as b
from adcc.functions import direct_sum, einsum, zeros_like
Expand All @@ -38,53 +39,15 @@
# really so much our focus.


class AdcBlock:
def __init__(self, apply, diagonal):
"""AdcBlock contains the matrix apply and diagonal
routines for a specific ADC matrix block and is used to compose
the dispatch routines for :py:class:`AdcMatrix`
Parameters
----------
apply : callable
function mapping an AmplitudeVector to the contribution of
this block to the result of applying the ADC matrix
diagonal : AmplitudeVector, int, float
expression to the diagonal of the ADC matrix from this block
"""
if not callable(apply):
raise TypeError("apply needs to be callable.")
if not isinstance(diagonal, (AmplitudeVector, int, float)):
raise TypeError("diagonal needs to be an AmplitudeVector or a number.")
self._applies = [apply]
self._diagonals = [diagonal]

def add_block(self, other):
"""Adds another :py:class:`AdcBlock` to this AdcBlock.
Parameters
----------
other : AdcBlock
block to be added
"""
if not isinstance(other, AdcBlock):
raise TypeError("other must be of type AdcBlock.")
self._applies.append(other.apply)
self._diagonals.append(other.diagonal)

def apply(self, ampl):
"""Applies the block to an input AmplitudeVector
Parameters
----------
ampl : AmplitudeVector
Input AmplitudeVector
"""
return sum(app(ampl) for app in self._applies)

@property
def diagonal(self):
return sum(self._diagonals)
#
# Dispatch routine
#
"""
`apply` is a function mapping an AmplitudeVector to the contribution of this
block to the result of applying the ADC matrix. `diagonal` is an `AmplitudeVector`
containing the expression to the diagonal of the ADC matrix from this block.
"""
AdcBlock = namedtuple("AdcBlock", ["apply", "diagonal"])


def block(ground_state, spaces, order, variant=None, intermediates=None):
Expand Down
8 changes: 0 additions & 8 deletions adcc/test_AdcMatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,14 +230,6 @@ def test_extra_term(self):
# we get 2*shift on the diagonal :(
ones = matrix.diagonal().ones_like()

with pytest.raises(TypeError):
AdcBlock(0, 0)
with pytest.raises(TypeError):
AdcBlock(lambda x: x, "fail")
with pytest.raises(TypeError):
abc = AdcBlock(lambda x: x, ones)
abc.add_block("fail")

with pytest.raises(TypeError):
AdcExtraTerm(matrix, "fail")
with pytest.raises(TypeError):
Expand Down

0 comments on commit 96b6f32

Please sign in to comment.