Skip to content

Commit

Permalink
cleanup and improve coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
maxscheurer committed Apr 22, 2021
1 parent 3189761 commit af39665
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 24 deletions.
10 changes: 4 additions & 6 deletions adcc/ExcitedStates.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,12 +195,10 @@ def __iadd__(self, other):
for k in other:
assert isinstance(k, EnergyCorrection)
self.__add_energy_correction(k)
elif isinstance(other, dict):
for k in other:
corr_fun = other[k]
assert callable(corr_fun)
enc = EnergyCorrection(k, corr_fun)
self.__add_energy_correction(enc)
else:
raise TypeError("Can only add EnergyCorrection (or list"
" of EnergyCorrection) to"
f" ExcitedState, not '{type(other)}'")
return self

def __add__(self, other):
Expand Down
52 changes: 34 additions & 18 deletions adcc/test_AdcMatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
## ---------------------------------------------------------------------
import adcc
import unittest
import pytest
import itertools
import numpy as np

Expand Down Expand Up @@ -226,6 +227,19 @@ def test_extra_term(self):
# we get 2*shift on the diagonal :(
ones = matrix.diagonal().ones_like()

with pytest.raises(TypeError):
AdcBlock(0, 0)
with pytest.raises(TypeError):
AdcBlock(lambda x: x, "fail")
with pytest.raises(TypeError):
abc = AdcBlock(lambda x: x, ones)
abc.add_block("fail")

with pytest.raises(TypeError):
AdcExtraTerm(matrix, "fail")
with pytest.raises(TypeError):
AdcExtraTerm(matrix, {"fail": "not_callable"})

def __shift_ph(hf, mp, intermediates):
def apply(invec):
return adcc.AmplitudeVector(ph=shift * invec.ph)
Expand All @@ -241,24 +255,26 @@ def apply(invec):
matrix, {'ph_ph': __shift_ph, 'pphh_pphh': __shift_pphh}
)
shifted_2 = matrix + extra
assert_allclose(
shifted.diagonal().ph.to_ndarray(),
shifted_2.diagonal().ph.to_ndarray(),
atol=1e-12
)
assert_allclose(
shifted.diagonal().pphh.to_ndarray(),
shifted_2.diagonal().pphh.to_ndarray(),
atol=1e-12
)
vec = adcc.guess_zero(matrix)
vec.set_random()
ref = shifted @ vec
ret = shifted_2 @ vec
diff_s = ref.ph - ret.ph
diff_d = ref.pphh - ret.pphh
assert np.max(np.abs(diff_s.to_ndarray())) < 1e-12
assert np.max(np.abs(diff_d.to_ndarray())) < 1e-12
shifted_3 = extra + matrix
for manual in [shifted_2, shifted_3]:
assert_allclose(
shifted.diagonal().ph.to_ndarray(),
manual.diagonal().ph.to_ndarray(),
atol=1e-12
)
assert_allclose(
shifted.diagonal().pphh.to_ndarray(),
manual.diagonal().pphh.to_ndarray(),
atol=1e-12
)
vec = adcc.guess_zero(matrix)
vec.set_random()
ref = shifted @ vec
ret = manual @ vec
diff_s = ref.ph - ret.ph
diff_d = ref.pphh - ret.pphh
assert np.max(np.abs(diff_s.to_ndarray())) < 1e-12
assert np.max(np.abs(diff_d.to_ndarray())) < 1e-12


@expand_test_templates(testcases)
Expand Down
3 changes: 3 additions & 0 deletions adcc/test_ExcitedStates.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ def base_test(self, system, method, kind):

with pytest.raises(ValueError):
state_corrected += cc2
with pytest.raises(TypeError):
state_corrected += 1

state_corrected2 = state_corrected + cc3
for i in range(state.size):
Expand All @@ -94,6 +96,7 @@ def base_test(self, system, method, kind):
corr = state.excitation_energy[i] ** 2 + 2.0 - 42.0
assert_allclose(state.excitation_energy[i] + corr,
state_corrected2.excitation_energy[i])
state_corrected2.describe()


class TestDataFrameExport(unittest.TestCase, Runners):
Expand Down

0 comments on commit af39665

Please sign in to comment.