Skip to content

Commit

Permalink
Pythonise AmplitudeVector
Browse files Browse the repository at this point in the history
  • Loading branch information
mfherbst committed Dec 1, 2020
1 parent b4e1d24 commit 02cd6f9
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 63 deletions.
5 changes: 3 additions & 2 deletions adcc/AdcMatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,10 +430,11 @@ def compute_matvec(self, ampl):

res = [bl.apply(ampl) for bl in self.__blocks.values()]
ph = sum(v.ph for v in res if "s" in v.blocks and v.ph)
pphh = None
if "d" in self.blocks:
pphh = sum(v.pphh for v in res if "d" in v.blocks and v.pphh)
return AmplitudeVector(ph=ph, pphh=pphh)
return AmplitudeVector(ph=ph, pphh=pphh)
else:
return AmplitudeVector(ph=ph)

def has_block(self, block): # TODO This should be deprecated
return block in self.blocks
Expand Down
142 changes: 81 additions & 61 deletions adcc/AmplitudeVector.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,91 +20,110 @@
## along with adcc. If not, see <http://www.gnu.org/licenses/>.
##
## ---------------------------------------------------------------------
import warnings

BLOCK_LABELS = ["s", "d", "t"]


# TODO Extend AmplitudeVector to cases where only the doubles block is present
class AmplitudeVector:
def __init__(self, *tensors, ph=None, pphh=None):
"""Initialise an AmplitudeVector from some blocks"""
# TODO ph and pphh are new-style constructors, they should
# take over passing their tensors into a dict.
if len(tensors) == 0:
if pphh is not None:
if ph is None:
ph = 0
self.tensors = [ph, pphh]
else:
assert ph is not None
self.tensors = [ph]
class AmplitudeVector(dict):
def __init__(self, *args, **kwargs):
"""
Construct an AmplitudeVector. Typical use cases are
``AmplitudeVector(ph=tensor_singles, pphh=tensor_doubles)``.
"""
if args:
warnings.warn("Using the list interface of AmplitudeVector is "
"deprecated and will be removed in version 0.15.1. Use "
"AmplitudeVector(ph=tensor_singles, pphh=tensor_doubles) "
"instead.")
if len(args) == 1:
super().__init__(ph=args[0])
elif len(args) == 2:
super().__init__(ph=args[0], pphh=args[1])
else:
assert ph is None and pphh is None
self.tensors = list(tensors)
super().__init__(**kwargs)

@property
def ph(self): # TODO temporary
return self.tensors[0]
def __getattr__(self, key):
if self.__contains__(key):
return self.__getitem__(key)
raise AttributeError

@property
def pphh(self): # TODO temporary
return self.tensors[1]

# TODO Attach some information about this Amplitude, e.g.
# is it CVS?
def blocks(self):
warnings.warn("The blocks function will change behaviour in 0.15.1.")
if sorted(self.blocks_ph) == ["ph", "pphh"]:
return ["s", "d"]
if sorted(self.blocks_ph) == ["pphh"]:
return ["d"]
elif sorted(self.blocks_ph) == ["ph"]:
return ["s"]
elif sorted(self.blocks_ph) == []:
return []
else:
raise NotImplementedError(self.blocks_ph)

@property
def blocks(self):
return [BLOCK_LABELS[i] for i in range(len(self.tensors))]
def blocks_ph(self):
"""
Return the blocks where are used inside the vector.
Note: This is a temporary name. The function will be removed in 0.15.1.
"""
return self.keys()

def __getitem__(self, index):
if isinstance(index, int):
return self.tensors[index]
elif isinstance(index, str):
if index not in BLOCK_LABELS:
raise ValueError("Invalid index, either a block string "
"like s,d,t, ... or an integer index "
"are expected.")
return self.__getitem__(BLOCK_LABELS.index(index))
if index in (0, 1, "s", "d"):
warnings.warn("Using the list interface of AmplitudeVector is "
"deprecated and will be removed in version 0.15.1. Use "
"block labels like 'ph', 'pphh' instead.")
if index in (0, "s"):
return self.__getitem__("ph")
elif index in (1, "d"):
return self.__getitem__("pphh")
else:
raise KeyError(index)
else:
return super().__getitem__(index)

def __setitem__(self, index, item):
if isinstance(index, int):
self.tensors[index] = item
elif isinstance(index, str):
if index not in BLOCK_LABELS:
raise ValueError("Invalid index, either a block string "
"like s,d,t, ... or an integer index "
"are expected.")
return self.__setitem__(BLOCK_LABELS.index(index), item)
if index in (0, 1, "s", "d"):
warnings.warn("Using the list interface of AmplitudeVector is "
"deprecated and will be removed in version 0.15.1. Use "
"block labels like 'ph', 'pphh' instead.")
if index in (0, "s"):
return self.__setitem__("ph", item)
elif index in (1, "d"):
return self.__setitem__("pphh", item)
else:
raise KeyError(index)
else:
super().__setitem__(index, item)

def copy(self):
"""Return a copy of the AmplitudeVector"""
return AmplitudeVector(*tuple(t.copy() for t in self.tensors))
return AmplitudeVector(**{k: t.copy() for k, t in self.items()})

def evaluate(self):
for t in self.tensors:
for t in self.values():
t.evaluate()
return self

def ones_like(self):
"""Return an empty AmplitudeVector of the same shape and symmetry"""
return AmplitudeVector(*tuple(t.ones_like() for t in self.tensors))
return AmplitudeVector(**{k: t.ones_like() for k, t in self.items()})

def empty_like(self):
"""Return an empty AmplitudeVector of the same shape and symmetry"""
return AmplitudeVector(*tuple(t.empty_like() for t in self.tensors))
return AmplitudeVector(**{k: t.empty_like() for k, t in self.items()})

def nosym_like(self):
"""Return an empty AmplitudeVector of the same shape and symmetry"""
return AmplitudeVector(*tuple(t.nosym_like() for t in self.tensors))
return AmplitudeVector(**{k: t.nosym_like() for k, t in self.items()})

def zeros_like(self):
"""Return an AmplitudeVector of the same shape and symmetry with
all elements set to zero"""
return AmplitudeVector(*tuple(t.zeros_like() for t in self.tensors))
return AmplitudeVector(**{k: t.zeros_like() for k, t in self.items()})

def set_random(self):
for t in self.tensors:
for t in self.values():
t.set_random()
return self

Expand All @@ -116,10 +135,9 @@ def dot(self, other):
if isinstance(other, list):
# Make a list where the first index is all singles parts,
# the second is all doubles parts and so on
alltensors = [[av[b] for av in other] for b in self.blocks]
return sum(t.dot(ots) for t, ots in zip(self.tensors, alltensors))
return sum(self[b].dot([av[b] for av in other]) for b in self.keys())
else:
return sum(t.dot(ot) for t, ot in zip(self.tensors, other.tensors))
return sum(self[b].dot(other[b]) for b in self.keys())

def __matmul__(self, other):
if isinstance(other, AmplitudeVector):
Expand All @@ -131,15 +149,17 @@ def __matmul__(self, other):

def __forward_to_blocks(self, fname, other):
if isinstance(other, AmplitudeVector):
ret = tuple(getattr(t, fname)(ot)
for t, ot in zip(self.tensors, other.tensors))
if sorted(other.blocks_ph) != sorted(self.blocks_ph):
raise ValueError("Blocks of both AmplitudeVector objects "
"need to agree")
ret = {k: getattr(tensor, fname)(other[k])
for k, tensor in self.items()}
else:
ret = tuple(getattr(t, fname)(other) for t in self.tensors)

if any(r == NotImplemented for r in ret):
ret = {k: getattr(tensor, fname)(other) for k, tensor in self.items()}
if any(r == NotImplemented for r in ret.values()):
return NotImplemented
else:
return AmplitudeVector(*ret)
return AmplitudeVector(**ret)

def __mul__(self, other):
return self.__forward_to_blocks("__mul__", other)
Expand Down Expand Up @@ -175,4 +195,4 @@ def __itruediv__(self, other):
return self.__forward_to_blocks("__itruediv__", other)

def __repr__(self):
return "AmplitudeVector(blocks=" + ",".join(self.blocks) + ")"
return "AmplitudeVector(" + "=..., ".join(self.blocks_ph) + "=...)"

0 comments on commit 02cd6f9

Please sign in to comment.