Skip to content

Commit

Permalink
more refactoring, make adcc usable again
Browse files Browse the repository at this point in the history
  • Loading branch information
maxscheurer committed Apr 30, 2021
1 parent 27d09f1 commit 2bd764f
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 11 deletions.
24 changes: 18 additions & 6 deletions adcc/AdcMatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ class AdcMatrix(AdcMatrixlike):
"adc3": dict(ph_ph=3, ph_pphh=2, pphh_ph=2, pphh_pphh=1), # noqa: E501
}

def __init__(self, method, hf_or_mp, block_orders=None, intermediates=None):
def __init__(self, method, hf_or_mp, block_orders=None, intermediates=None,
diagonal_precomputed=None):
"""
Initialise an ADC matrix.
Expand All @@ -101,6 +102,8 @@ def __init__(self, method, hf_or_mp, block_orders=None, intermediates=None):
If not set, defaults according to the selected ADC method are chosen.
intermediates : adcc.Intermediates or NoneType
Allows to pass intermediates to re-use to this class.
diagonal_precomputed: adcc.AmplitudeVector
Allows to pass a pre-computed diagonal, for internal use only.
"""
if isinstance(hf_or_mp, (libadcc.ReferenceState,
libadcc.HartreeFockSolution_i)):
Expand Down Expand Up @@ -159,9 +162,18 @@ def __init__(self, method, hf_or_mp, block_orders=None, intermediates=None):
}
# TODO Rename to self.block in 0.16.0
self.blocks_ph = {bl: blocks[bl].apply for bl in blocks}
self.__diagonal = sum(bl.diagonal for bl in blocks.values()
if bl.diagonal)
self.__diagonal.evaluate()
if diagonal_precomputed:
if not isinstance(diagonal_precomputed, AmplitudeVector):
raise TypeError("diagonal_precomputed needs to be"
" an AmplitudeVector.")
if diagonal_precomputed.needs_evaluation:
raise ValueError("diagonal_precomputed must already"
" be evaluated.")
self.__diagonal = diagonal_precomputed
else:
self.__diagonal = sum(bl.diagonal for bl in blocks.values()
if bl.diagonal)
self.__diagonal.evaluate()
self.__init_space_data(self.__diagonal)

def __iadd__(self, other):
Expand Down Expand Up @@ -208,10 +220,10 @@ def __add__(self, other):
"""
if not isinstance(other, AdcExtraTerm):
return NotImplemented
# NOTE: re-computes the (expensive) diagonal...
ret = AdcMatrix(self.method, self.ground_state,
block_orders=self.block_orders,
intermediates=self.intermediates)
intermediates=self.intermediates,
diagonal_precomputed=self.diagonal())
ret += other
return ret

Expand Down
4 changes: 4 additions & 0 deletions adcc/AmplitudeVector.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ def evaluate(self):
t.evaluate()
return self

@property
def needs_evaluation(self):
return any(t.needs_evaluation for k, t in self.items())

def ones_like(self):
"""Return an empty AmplitudeVector of the same shape and symmetry"""
return AmplitudeVector(**{k: t.ones_like() for k, t in self.items()})
Expand Down
8 changes: 8 additions & 0 deletions adcc/test_AdcMatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,14 @@ def test_extra_term(self):
with pytest.raises(TypeError):
matrix_adc1 += 42
matrix = adcc.AdcMatrix("adc2", ground_state)

with pytest.raises(TypeError):
adcc.AdcMatrix("adc2", ground_state,
diagonal_precomputed=42)
with pytest.raises(ValueError):
adcc.AdcMatrix("adc2", ground_state,
diagonal_precomputed=matrix.diagonal() + 42)

shift = -0.3
shifted = AdcMatrixShifted(matrix, shift)
# TODO: need to do this to differentiate between
Expand Down
4 changes: 3 additions & 1 deletion adcc/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,9 +542,11 @@ def setup_environment(matrix, environment):
)
elif environment and not hf.environment:
raise InputError(
"environment specified, but no environment"
"Environment specified, but no environment"
" was found in reference state."
)
elif not hf.environment:
environment = False

if isinstance(environment, bool):
environment = {"ptss": True, "ptlr": True} if environment else {}
Expand Down
2 changes: 1 addition & 1 deletion examples/pna/psi4_adc2_pna_6w_pol_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,5 @@

# Run an adc2 calculation:
state = adcc.adc2(wfn, n_singlets=5, conv_tol=1e-8,
solvent_scheme=['ptss', 'ptlr'])
environment=True)
print(state.describe())
6 changes: 3 additions & 3 deletions examples/pna/pyscf_adc2_pna_6w_pol_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@

# model the solvent through perturbative corrections
state_pt = adcc.adc2(scfres, n_singlets=5, conv_tol=1e-5,
solvent_scheme=['ptss', 'ptlr'])
environment=['ptss', 'ptlr'])

# now model the solvent through linear-response/postscf coupling
# now model the solvent through linear-response coupling
# in the ADC matrix, re-using the matrix from previous run.
# This will modify state_pt.matrix
state_lr = adcc.run_adc(state_pt.matrix, n_singlets=5, conv_tol=1e-5,
solvent_scheme='postscf')
environment='linear_response')

print(state_pt.describe())
print(state_lr.describe())

0 comments on commit 2bd764f

Please sign in to comment.