Introduce abstract NeuralState class
emerali committed Dec 10, 2019
1 parent 1e7e1b1 commit d2c4239
Showing 9 changed files with 675 additions and 593 deletions.
61 changes: 30 additions & 31 deletions qucumber/nn_states/complex_wavefunction.py
@@ -51,15 +51,11 @@ class ComplexWaveFunction(WaveFunctionBase):
     _device = None
 
     def __init__(
-        self, num_visible, num_hidden=None, unitary_dict=None, gpu=True, module=None
+        self, num_visible, num_hidden=None, unitary_dict=None, gpu=False, module=None
     ):
         if gpu and torch.cuda.is_available():
             warnings.warn(
-                (
-                    "Using ComplexWaveFunction on GPU is not recommended due to poor "
-                    "performance compared to CPU. In the future, ComplexWaveFunction "
-                    "will default to using CPU, even if a GPU is available."
-                ),
+                "Using ComplexWaveFunction on GPU is not recommended due to poor performance compared to CPU.",
                 ResourceWarning,
                 2,
             )
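Note: this hunk flips the constructor's default from GPU to CPU, making the "will default to CPU" promise in the old warning text a reality. A minimal usage sketch under the new default (the 10-site system size and the `qucumber.nn_states` import path are assumptions, not shown in this diff):

    from qucumber.nn_states import ComplexWaveFunction

    # After this commit, construction defaults to CPU; GPU use requires an
    # explicit gpu=True and still triggers the ResourceWarning above.
    nn_state = ComplexWaveFunction(num_visible=10)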
@@ -68,25 +64,18 @@ def __init__(
         self.device = torch.device("cpu")
 
         if module is None:
-            self.rbm_am = BinaryRBM(
-                int(num_visible),
-                int(num_hidden) if num_hidden else int(num_visible),
-                gpu=gpu,
-            )
-            self.rbm_ph = BinaryRBM(
-                int(num_visible),
-                int(num_hidden) if num_hidden else int(num_visible),
-                gpu=gpu,
-            )
+            self.rbm_am = BinaryRBM(num_visible, num_hidden, gpu=gpu)
+            self.rbm_ph = BinaryRBM(num_visible, num_hidden, gpu=gpu)
         else:
             _warn_on_missing_gpu(gpu)
             self.rbm_am = module.to(self.device)
             self.rbm_am.device = self.device
             self.rbm_ph = module.to(self.device).clone()
             self.rbm_ph.device = self.device
 
-        self.num_visible = int(num_visible)
-        self.num_hidden = int(num_hidden) if num_hidden else self.num_visible
+        self.num_visible = self.rbm_am.num_visible
+        self.num_hidden = self.rbm_am.num_hidden
+        self.device = self.rbm_am.device
 
         self.unitary_dict = unitary_dict if unitary_dict else unitaries.create_dict()
         self.unitary_dict = {
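The `int(...)` casts and the `num_hidden`-defaults-to-`num_visible` logic disappear here, and the wavefunction now reads `num_visible`, `num_hidden`, and `device` back from its amplitude RBM. Presumably `BinaryRBM` itself absorbed that defaulting as part of the refactor; a hypothetical sketch of what its constructor would then contain (not part of this diff):

    import torch.nn as nn

    class BinaryRBM(nn.Module):  # hypothetical sketch, not this diff's code
        def __init__(self, num_visible, num_hidden=None, gpu=True):
            super().__init__()
            self.num_visible = int(num_visible)
            # Hidden layer defaults to the visible-layer size when unspecified.
            self.num_hidden = int(num_hidden) if num_hidden else self.num_visible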
@@ -169,20 +158,16 @@ def psi(self, v):
            each visible state
         :rtype: torch.Tensor
         """
-        # vectors/tensors of shape (len(v),)
-        amplitude, phase = self.amplitude(v), self.phase(v)
-
-        # complex vector; shape: (2, len(v))
-        psi = torch.zeros(
-            (2,) + amplitude.shape, dtype=torch.double, device=self.device
-        )
-
-        # elementwise products
-        psi[0] = amplitude * phase.cos()  # real part
-        psi[1] = amplitude * phase.sin()  # imaginary part
-        return psi
+        return super().psi(v)
 
     def init_gradient(self, basis, sites):
+        r"""Initializes all required variables for gradient computation
+
+        :param basis: The bases of the measurements
+        :type basis: numpy.ndarray
+        :param sites: The sites where the measurements are not
+                      in the computational basis
+        """
         Upsi = torch.zeros(2, dtype=torch.double, device=self.device)
         Us = torch.stack([self.unitary_dict[b] for b in basis[sites]]).cpu().numpy()
         rotated_grad = [
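In this hunk the explicit construction of psi moves out of ComplexWaveFunction: `super().psi(v)` presumably reproduces the deleted logic in the new base class (its exact home there is an assumption). For reference, the removed lines computed, in effect:

    import torch

    def psi(self, v):  # as the deleted lines computed it
        # vectors/tensors of shape (len(v),)
        amplitude, phase = self.amplitude(v), self.phase(v)
        # complex vector encoded as a real tensor of shape (2, len(v))
        psi = torch.zeros((2,) + amplitude.shape, dtype=torch.double, device=self.device)
        psi[0] = amplitude * phase.cos()  # real part
        psi[1] = amplitude * phase.sin()  # imaginary part
        return psi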
@@ -194,6 +179,20 @@ def init_gradient(self, basis, sites):
         return Upsi, Us, rotated_grad
 
     def rotated_gradient(self, basis, sites, sample):
+        r"""Computes the gradients rotated into the measurement basis
+
+        :param basis: The bases in which the measurement is made
+        :type basis: numpy.ndarray
+        :param sites: The sites where the measurements are not made
+                      in the computational basis
+        :type sites: numpy.ndarray
+        :param sample: The measurement (either 0 or 1)
+        :type sample: torch.Tensor
+
+        :returns: A list of two tensors, representing the rotated gradients
+                  of the amplitude and phase RBMs
+        :rtype: list[torch.Tensor, torch.Tensor]
+        """
         Upsi, Us, rotated_grad = self.init_gradient(basis, sites)
         int_sample = sample[sites].round().int().cpu().numpy()
         ints_size = np.arange(sites.size)
@@ -269,7 +268,7 @@ def rotated_gradient(self, basis, sites, sample):
 
         return grad
 
-    def gradient(self, basis, sample):
+    def gradient(self, sample, basis):
         r"""Compute the gradient of a sample, measured in different bases.
 
         :param basis: A set of bases.
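The final hunk swaps the order of `gradient`'s two parameters, so any caller written against the old signature needs updating. A hypothetical call site (variable names are illustrative):

    # before this commit
    grads = nn_state.gradient(basis, sample)
    # after this commit
    grads = nn_state.gradient(sample, basis)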