Skip to content

Commit

Permalink
Fix and cleanup gradient tests
Browse files Browse the repository at this point in the history
  • Loading branch information
emerali committed Dec 20, 2019
1 parent 57ebcd8 commit 0638cef
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 327 deletions.
73 changes: 36 additions & 37 deletions qucumber/utils/training_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _kron_mult(matrices, x):
return y


def rotate_psi(nn_state, basis, space, unitaries, psi=None):
def rotate_psi(nn_state, basis, space, unitaries=None, psi=None):
r"""A function that rotates the reconstructed wavefunction to a different
basis.
Expand All @@ -67,12 +67,13 @@ def rotate_psi(nn_state, basis, space, unitaries, psi=None):
else psi.to(dtype=torch.double, device=nn_state.device)
)

unitaries = unitaries if unitaries else nn_state.unitary_dict
unitaries = {k: v.to(device=nn_state.device) for k, v in unitaries.items()}
us = [unitaries[b] for b in basis]
return _kron_mult(us, psi)


def rotate_rho(nn_state, basis, space, unitaries, rho=None):
def rotate_rho(nn_state, basis, space, unitaries=None, rho=None):
r"""Computes the density matrix rotated into some basis
:param nn_state: The density matrix neural network state.
Expand All @@ -97,6 +98,7 @@ def rotate_rho(nn_state, basis, space, unitaries, rho=None):
else rho.to(dtype=torch.double, device=nn_state.device)
)

unitaries = unitaries if unitaries else nn_state.unitary_dict
unitaries = {k: v.to(device=nn_state.device) for k, v in unitaries.items()}
us = [unitaries[b] for b in basis]

Expand Down Expand Up @@ -164,9 +166,8 @@ def fidelity(nn_state, target, space, **kwargs):
def NLL(nn_state, samples, space, bases=None, **kwargs):
r"""A function for calculating the negative log-likelihood (NLL).
:param nn_state: The neural network state (i.e. complex wavefunction or
positive wavefunction).
:type nn_state: qucumber.nn_states.WaveFunctionBase
:param nn_state: The neural network state.
:type nn_state: qucumber.nn_states.NeuralStateBase
:param samples: Samples to compute the NLL on.
:type samples: torch.Tensor
:param space: The basis elements of the Hilbert space of the system :math:`\mathcal{H}`.
Expand All @@ -178,36 +179,43 @@ def NLL(nn_state, samples, space, bases=None, **kwargs):
:returns: The Negative Log-Likelihood.
:rtype: float
"""
psi_r = torch.zeros(
2, 1 << nn_state.num_visible, dtype=torch.double, device=nn_state.device
)
NLL = 0.0
Z = nn_state.compute_normalization(space)
Z = nn_state.normalization(space)

if bases is None:
nn_probs = nn_state.probability(samples, Z)
NLL = -torch.sum(probs_to_logits(nn_probs))
NLL_ = -torch.mean(probs_to_logits(nn_probs)).item()
return NLL_
else:
unitary_dict = nn_state.unitary_dict
NLL_ = 0.0

indices = 2 ** (torch.arange(nn_state.num_visible, 0, -1) - 1)
indices = torch.mv(samples, indices.to(samples)).long()

for i in range(len(samples)):
# Check whether the sample was measured in the reference basis
is_reference_basis = True
for j in range(nn_state.num_visible):
if bases[i][j] != "Z":
is_reference_basis = False
break

if is_reference_basis is True:
nn_probs = nn_state.probability(samples[i], Z)
NLL -= torch.sum(probs_to_logits(nn_probs))
NLL_ -= torch.sum(probs_to_logits(nn_probs))
else:
psi_r = rotate_psi(nn_state, bases[i], space, unitary_dict)
# Get the index value of the sample state
ind = 0
for j in range(nn_state.num_visible):
if samples[i, nn_state.num_visible - j - 1] == 1:
ind += pow(2, j)
probs_r = cplx.norm_sqr(psi_r[:, ind]) / Z
NLL -= probs_to_logits(probs_r).item()
return (NLL / float(len(samples))).item()
ind = indices[i]

if isinstance(nn_state, WaveFunctionBase):
psi_r = rotate_psi(nn_state, bases[i], space)
probs_r = cplx.norm_sqr(psi_r[:, ind]) / Z
NLL_ -= probs_to_logits(probs_r).item()
else:
rho_r = rotate_rho(nn_state, bases[i], space)
probs_r = torch.diagonal(cplx.real(rho_r))[ind] / Z
NLL_ -= probs_to_logits(probs_r).item()

return NLL_ / float(len(samples))


def _single_basis_KL(target_probs, nn_probs):
Expand All @@ -217,7 +225,8 @@ def _single_basis_KL(target_probs, nn_probs):


def KL(nn_state, target, space, bases=None, **kwargs):
r"""A function for calculating the total KL divergence.
r"""A function for calculating the KL divergence averaged over every given
basis.
.. math:: KL(P_{target} \vert P_{RBM}) = -\sum_{x \in \mathcal{H}} P_{target}(x)\log(\frac{P_{RBM}(x)}{P_{target}(x)})
Expand Down Expand Up @@ -261,42 +270,32 @@ def KL(nn_state, target, space, bases=None, **kwargs):
KL += _single_basis_KL(target_probs, nn_probs)

elif isinstance(nn_state, WaveFunctionBase):
unitary_dict = nn_state.unitary_dict

for basis in bases:
psi_r = rotate_psi(nn_state, basis, space, unitary_dict)
if isinstance(target, dict):
target_psi_r = target[basis]
assert target_psi_r.dim() == 2, "target must be a complex vector!"
else:
assert target.dim() == 2, "target must be a complex vector!"
target_psi_r = rotate_psi(nn_state, basis, space, psi=target)

target_psi_r = rotate_psi(
nn_state, basis, space, unitary_dict, psi=target
)

psi_r = rotate_psi(nn_state, basis, space)
nn_probs_r = (cplx.absolute_value(psi_r) ** 2) / Z
target_probs_r = cplx.absolute_value(target_psi_r) ** 2

KL += _single_basis_KL(target_probs_r, nn_probs_r)

KL /= float(len(bases))
else:
unitary_dict = nn_state.unitary_dict

for basis in bases:
rho_r = rotate_rho(nn_state, basis, space, unitary_dict) / Z
if isinstance(target, dict):
target_rho_r = target[basis]
assert target_rho_r.dim() == 3, "target must be a complex matrix!"
else:
assert target.dim() == 3, "target must be a complex matrix!"
target_rho_r = rotate_rho(nn_state, basis, space, rho=target)

target_rho_r = rotate_rho(
nn_state, basis, space, unitary_dict, rho=target
)

nn_probs_r = torch.diagonal(cplx.real(rho_r))
rho_r = rotate_rho(nn_state, basis, space)
nn_probs_r = torch.diagonal(cplx.real(rho_r)) / Z
target_probs_r = torch.diagonal(cplx.real(target_rho_r))

KL += _single_basis_KL(target_probs_r, nn_probs_r)
Expand Down

0 comments on commit 0638cef

Please sign in to comment.