Skip to content

Commit

Permalink
composable AdcExtraTerm stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
maxscheurer committed Apr 19, 2021
1 parent 518e89a commit 72b0d7c
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 116 deletions.
89 changes: 0 additions & 89 deletions .clang-format

This file was deleted.

67 changes: 47 additions & 20 deletions adcc/AdcMatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,29 @@
from .AmplitudeVector import AmplitudeVector


class AdcExtraTerm:
"""
"""
def __init__(self, matrix, blocks):
self.ground_state = matrix.ground_state
self.reference_state = matrix.reference_state
self.intermediates = matrix.intermediates

self.blocks = {}
if not isinstance(blocks, dict):
raise TypeError("blocks needs to "
" be a dict.")
for space in blocks:
block_fun = blocks[space]
if not callable(block_fun):
raise TypeError("Items in additional_blocks "
"must be callable.")
block = block_fun(
self.reference_state, self.ground_state, self.intermediates
)
self.blocks[space] = block


class AdcMatrixlike:
"""
Base class marker for all objects like ADC matrices.
Expand All @@ -50,8 +73,7 @@ class AdcMatrix(AdcMatrixlike):
"adc3": dict(ph_ph=3, ph_pphh=2, pphh_ph=2, pphh_pphh=1), # noqa: E501
}

def __init__(self, method, hf_or_mp, block_orders=None, intermediates=None,
additional_blocks=None):
def __init__(self, method, hf_or_mp, block_orders=None, intermediates=None):
"""
Initialise an ADC matrix.
Expand Down Expand Up @@ -85,6 +107,7 @@ def __init__(self, method, hf_or_mp, block_orders=None, intermediates=None,
self.mospaces = hf_or_mp.reference_state.mospaces
self.is_core_valence_separated = method.is_core_valence_separated
self.ndim = 2
self.extra_terms = []

self.intermediates = intermediates
if self.intermediates is None:
Expand Down Expand Up @@ -113,35 +136,39 @@ def __init__(self, method, hf_or_mp, block_orders=None, intermediates=None,
# Build the blocks and diagonals
with self.timer.record("build"):
variant = None
if method.is_core_valence_separated:
if self.is_core_valence_separated:
variant = "cvs"
self.blocks_ph = { # TODO Rename to self.block in 0.16.0
block: ppmatrix.block(self.ground_state, block.split("_"),
order=order, intermediates=self.intermediates,
variant=variant)
for block, order in block_orders.items() if order is not None
for block, order in self.block_orders.items() if order is not None
}
if additional_blocks is not None:
if not isinstance(additional_blocks, dict):
raise TypeError("additional_blocks needs to "
" be a dict.")
for space in additional_blocks:
if space not in self.blocks_ph:
raise ValueError("Can only add blocks"
" to existing matrix blocks.")
block_fun = additional_blocks[space]
if not callable(block_fun):
raise TypeError("Items in additional_blocks "
"must be callable.")
block = block_fun(
self.reference_state, self.ground_state, self.intermediates
)
self.blocks_ph[space].add_block(block)
self.__diagonal = sum(bl.diagonal for bl in self.blocks_ph.values()
if bl.diagonal)
self.__diagonal.evaluate()
self.__init_space_data(self.__diagonal)

def __iadd__(self, other):
assert isinstance(other, AdcExtraTerm)
assert all(k in self.blocks_ph for k in other.blocks)
for sp in other.blocks:
self.blocks_ph[sp].add_block(other.blocks[sp])
# TODO: always re-computes the (expensive) diagonal...
with self.timer.record("build"):
self.__diagonal = sum(bl.diagonal for bl in self.blocks_ph.values()
if bl.diagonal)
self.__diagonal.evaluate()
self.extra_terms.append(other)
return self

def __add__(self, other):
ret = AdcMatrix(self.method, self.ground_state,
block_orders=self.block_orders,
intermediates=self.intermediates)
ret += other
return ret

def __init_space_data(self, diagonal):
"""Update the cached data regarding the spaces of the ADC matrix"""
self.axis_spaces = {}
Expand Down
9 changes: 5 additions & 4 deletions adcc/backends/test_backends_polembed.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from ..testdata.cache import qchem_data, tmole_data
from ..testdata.static_data import pe_potentials

from ..AdcMatrix import AdcExtraTerm
from ..adc_pp.solvent import block_ph_ph_0_pe

try:
Expand Down Expand Up @@ -95,10 +96,10 @@ def template_pe_coupling_formaldehyde(self, basis, method, backend):
pe_options=pe_options)
assert_allclose(scfres.energy_scf, tm_result["energy_scf"], atol=1e-8)

matrix = adcc.AdcMatrix(
method, scfres,
additional_blocks={'ph_ph': block_ph_ph_0_pe}
)
matrix = adcc.AdcMatrix(method, scfres)
solvent = AdcExtraTerm(matrix, {'ph_ph': block_ph_ph_0_pe})
matrix += solvent
assert len(matrix.extra_terms)

assert_allclose(
matrix.ground_state.energy(2),
Expand Down
9 changes: 8 additions & 1 deletion adcc/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,13 +191,20 @@ def run_adc(data_or_matrix, n_states=None, kind="any", conv_tol=None,
if eigensolver is None:
eigensolver = "davidson"

# add solvent coupling terms to matrix?
# matrix += pe_coupling_term
# solvent_scheme: scf, pt, ptlr, ptss, ...

diagres = diagonalise_adcmatrix(
matrix, n_states, kind, guesses=guesses, n_guesses=n_guesses,
n_guesses_doubles=n_guesses_doubles, conv_tol=conv_tol, output=output,
eigensolver=eigensolver, **solverargs)
exstates = ExcitedStates(diagres)
exstates.kind = kind
exstates.spin_change = spin_change
exstates.spin_change = spin_change

# if solvent.ptss:
# exstates += ptss_correction
return exstates


Expand Down
2 changes: 1 addition & 1 deletion libadcc/TensorImpl.hh
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class TensorImpl : public Tensor {
std::string describe_expression(std::string stage = "unoptimised") const override;

/** Convert object to btensor for use in libtensor functions. */
explicit operator libtensor::btensor<N, scalar_type>&() { return *libtensor_ptr(); }
explicit operator libtensor::btensor<N, scalar_type> &() { return *libtensor_ptr(); }

/** Return inner btensor object
*
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def libadcc_extension():
def is_conda_build():
return (
os.environ.get("CONDA_BUILD", None) == "1"
or os.environ.get("CONDA_EXE", None)
or os.environ.get("CONDA_EXE", None) is not None
)


Expand Down

0 comments on commit 72b0d7c

Please sign in to comment.