Skip to content

Commit

Permalink
Reduce code duplication in NeuralState subclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
emerali committed Dec 10, 2019
1 parent 691f5ce commit 459fee9
Show file tree
Hide file tree
Showing 9 changed files with 138 additions and 461 deletions.
13 changes: 11 additions & 2 deletions docs/quantum_states.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ Complex WaveFunction
:inherited-members:
:show-inheritance:

Density Matrix
------------------------------

.. autoclass:: qucumber.nn_states.DensityMatrix
:members:
:show-inheritance:

Abstract WaveFunction
------------------------------

Expand All @@ -29,9 +36,11 @@ Abstract WaveFunction
:members:
:show-inheritance:

Density Matrix
Abstract NeuralState
------------------------------

.. autoclass:: qucumber.nn_states.DensityMatrix
.. note:: |AbstractClassNote|

.. autoclass:: qucumber.nn_states.NeuralStateBase
:members:
:show-inheritance:
6 changes: 3 additions & 3 deletions qucumber/callbacks/metric_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ class MetricEvaluator(CallbackBase):
This callback is called at the end of each epoch.
.. note::
Since callbacks are given to :func:`fit<qucumber.nn_states.WaveFunctionBase.fit>`
Since callbacks are given to :func:`fit<qucumber.nn_states.NeuralStateBase.fit>`
as a list, they will be called in a deterministic order. It is
therefore recommended that instances of
:class:`MetricEvaluator<MetricEvaluator>` be among the first callbacks in
the list passed to :func:`fit<qucumber.nn_states.WaveFunctionBase.fit>`,
the list passed to :func:`fit<qucumber.nn_states.NeuralStateBase.fit>`,
as one would often use it in conjunction with other callbacks like
:class:`EarlyStopping<EarlyStopping>` which may depend on
:class:`MetricEvaluator<MetricEvaluator>` having been called.
Expand All @@ -39,7 +39,7 @@ class MetricEvaluator(CallbackBase):
metric(s).
:type period: int
:param metrics: A dictionary of callables where the keys are the names of
the metrics and the callables take the WaveFunction being trained
the metrics and the callables take the NeuralState being trained
as their positional argument, along with some keyword
arguments. The metrics are evaluated and put into an internal
dictionary structure resembling the structure of `metrics`.
Expand Down
6 changes: 3 additions & 3 deletions qucumber/callbacks/observable_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ class ObservableEvaluator(CallbackBase):
This callback is called at the end of each epoch.
.. note::
Since callback are given to :func:`fit<qucumber.nn_states.WaveFunctionBase.fit>`
Since callback are given to :func:`fit<qucumber.nn_states.NeuralStateBase.fit>`
as a list, they will be called in a deterministic order. It is
therefore recommended that instances of
:class:`ObservableEvaluator<ObservableEvaluator>` be among the first callbacks in
the list passed to :func:`fit<qucumber.nn_states.WaveFunctionBase.fit>`,
the list passed to :func:`fit<qucumber.nn_states.NeuralStateBase.fit>`,
as one would often use it in conjunction with other callbacks like
:class:`EarlyStopping<EarlyStopping>` which may depend on
:class:`ObservableEvaluator<ObservableEvaluator>` having been called.
Expand All @@ -80,7 +80,7 @@ class ObservableEvaluator(CallbackBase):
observables(s).
:type period: int
:param observables: A list of Observables. Observable statistics are
evaluated by sampling the WaveFunction. Note that
evaluated by sampling the NeuralState. Note that
observables that have the same name will conflict,
and precedence will be given to the one which appears
later in the list.
Expand Down
5 changes: 4 additions & 1 deletion qucumber/nn_states/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from .neural_state import NeuralStateBase

from .wavefunction import WaveFunctionBase
from .complex_wavefunction import ComplexWaveFunction
from .positive_wavefunction import PositiveWaveFunction
from .wavefunction import WaveFunctionBase

from .density_matrix import DensityMatrix
51 changes: 7 additions & 44 deletions qucumber/nn_states/complex_wavefunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,44 +268,6 @@ def rotated_gradient(self, basis, sites, sample):

return grad

def gradient(self, sample, basis):
r"""Compute the gradient of a sample, measured in different bases.
:param basis: A set of bases.
:type basis: numpy.ndarray
:param sample: A sample to compute the gradient of.
:type sample: numpy.ndarray
:returns: A list of 2 tensors containing the parameters of each of the
internal RBMs.
:rtype: list[torch.Tensor]
"""
basis = np.array(list(basis)) # list is silly, but works for now
rot_sites = np.where(basis != "Z")[0]
if rot_sites.size == 0:
grad = [
self.rbm_am.effective_energy_gradient(sample), # Real
0.0, # Imaginary
]
else:
grad = self.rotated_gradient(basis, rot_sites, sample)
return grad

def compute_normalization(self, space):
r"""Compute the normalization constant of the wavefunction.
.. math::
Z_{\bm{\lambda}}=
\sqrt{\sum_{\bm{\sigma}}|\psi_{\bm{\lambda\mu}}|^2}=
\sqrt{\sum_{\bm{\sigma}} p_{\bm{\lambda}}(\bm{\sigma})}
:param space: A rank 2 tensor of the entire visible space.
:type space: torch.Tensor
"""
return super().compute_normalization(space)

def fit(
self,
data,
Expand All @@ -320,6 +282,9 @@ def fit(
time=False,
callbacks=None,
optimizer=torch.optim.SGD,
optimizer_args=None,
scheduler=None,
scheduler_args=None,
**kwargs
):
if input_bases is None:
Expand All @@ -340,16 +305,14 @@ def fit(
time=time,
callbacks=callbacks,
optimizer=optimizer,
optimizer_args=optimizer_args,
scheduler=scheduler,
scheduler_args=scheduler_args,
**kwargs
)

def save(self, location, metadata=None):
metadata = metadata if metadata else {}
metadata["unitary_dict"] = self.unitary_dict
super().save(location, metadata=metadata)

@staticmethod
def autoload(location, gpu=True):
def autoload(location, gpu=False):
state_dict = torch.load(location)
wvfn = ComplexWaveFunction(
unitary_dict=state_dict["unitary_dict"],
Expand Down

0 comments on commit 459fee9

Please sign in to comment.