Absorb density matrix training statistics into the main ones
emerali committed Dec 18, 2019
1 parent 59c45bb commit fbf9dd9
Showing 3 changed files with 114 additions and 123 deletions.
225 changes: 108 additions & 117 deletions qucumber/utils/training_statistics.py
@@ -18,36 +18,10 @@
import numpy as np
from scipy.linalg import sqrtm

from qucumber.nn_states import WaveFunctionBase
import qucumber.utils.cplx as cplx


def fidelity(nn_state, target_psi, space, **kwargs):
r"""Calculates the square of the overlap (fidelity) between the reconstructed
wavefunction and the true wavefunction (both in the computational basis).
.. math:: F = \vert \langle \psi_{RBM} \vert \psi_{target} \rangle \vert ^2
:param nn_state: The neural network state (i.e. complex wavefunction or
positive wavefunction).
:type nn_state: qucumber.nn_states.WaveFunctionBase
:param target_psi: The true wavefunction of the system.
:type target_psi: torch.Tensor
:param space: The basis elements of the Hilbert space of the system :math:`\mathcal{H}`.
The ordering of the basis elements must match with the ordering of the
coefficients given in `target_psi`.
:type space: torch.Tensor
:param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.
:returns: The fidelity.
:rtype: float
"""
Z = nn_state.compute_normalization(space)
target_psi = target_psi.to(nn_state.device)
psi = nn_state.psi(space) / Z.sqrt()
F = cplx.inner_prod(target_psi, psi)
return cplx.absolute_value(F).pow_(2).item()


def _kron_mult(matrices, x):
n = [m.size()[0] for m in matrices]
l, r = np.prod(n), 1 # noqa: E741
@@ -132,6 +106,61 @@ def rotate_rho(nn_state, basis, space, unitaries, rho=None):
return rho_r


def fidelity(nn_state, target, space, **kwargs):
r"""Calculates the square of the overlap (fidelity) between the reconstructed
state and the true state (both in the computational basis).
.. math::
F = \vert \langle \psi_{RBM} \vert \psi_{target} \rangle \vert ^2
= \left( \tr \lbrack \sqrt{ \sqrt{\rho_{RBM}} \rho_{target} \sqrt{\rho_{RBM}} } \rbrack \right) ^ 2
:param nn_state: The neural network state.
:type nn_state: qucumber.nn_states.NeuralStateBase
:param target: The true state of the system.
:type target: torch.Tensor
:param space: The basis elements of the Hilbert space of the system :math:`\mathcal{H}`.
The ordering of the basis elements must match with the ordering of the
coefficients given in `target`.
:type space: torch.Tensor
:param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.
:returns: The fidelity.
:rtype: float
"""
Z = nn_state.normalization(space)
target = target.to(nn_state.device)

if isinstance(nn_state, WaveFunctionBase):
assert target.dim() == 2, "target must be a complex vector!"

psi = nn_state.psi(space) / Z.sqrt()
F = cplx.inner_prod(target, psi)
return cplx.absolute_value(F).pow_(2).item()
else:
assert target.dim() == 3, "target must be a complex matrix!"

rho = nn_state.rho(space, space) / Z
arg_real = cplx.real(rho).numpy()
arg_imag = cplx.imag(rho).numpy()

rho_rbm_ = arg_real + 1j * arg_imag

arg_real = cplx.real(target).numpy()
arg_imag = cplx.imag(target).numpy()

target_ = arg_real + 1j * arg_imag

sqrt_rho_rbm = sqrtm(rho_rbm_)
prod = np.matmul(sqrt_rho_rbm, np.matmul(target_, sqrt_rho_rbm))

# Instead of sqrt'ing then taking the trace, we compute the eigenvals,
# sqrt those, and then sum them up. This is a bit more efficient.
eigvals = np.linalg.eigvalsh(prod)
trace = np.sum(np.sqrt(eigvals).real) # imaginary parts should be zero
return trace ** 2
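
Aside (not part of the diff): the comment above relies on the fact that, for a positive semi-definite matrix, the trace of its matrix square root equals the sum of the square roots of its eigenvalues. A minimal NumPy/SciPy sketch checking this on made-up density matrices:

import numpy as np
from scipy.linalg import sqrtm

# Two made-up 4x4 density matrices (Hermitian, positive semi-definite, trace 1).
rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4)) + 1j * rng.normal(size=(4, 4))
rho_rbm = A @ A.conj().T / np.trace(A @ A.conj().T).real
B = rng.normal(size=(4, 4)) + 1j * rng.normal(size=(4, 4))
rho_target = B @ B.conj().T / np.trace(B @ B.conj().T).real

sqrt_rho = sqrtm(rho_rbm)
prod = sqrt_rho @ rho_target @ sqrt_rho  # Hermitian and PSD up to numerical noise

fid_direct = np.trace(sqrtm(prod)).real ** 2               # sqrtm, then trace
eigvals = np.linalg.eigvalsh(prod)                         # eigenvalue shortcut
fid_eig = np.sum(np.sqrt(np.clip(eigvals, 0, None))) ** 2  # clip tiny negatives

assert np.isclose(fid_direct, fid_eig)
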


def NLL(nn_state, samples, space, bases=None, **kwargs):
r"""A function for calculating the negative log-likelihood (NLL).
@@ -181,21 +210,26 @@ def NLL(nn_state, samples, space, bases=None, **kwargs):
return (NLL / float(len(samples))).item()


def KL(nn_state, target_psi, space, bases=None, **kwargs):
def _single_basis_KL(target_probs, nn_probs):
return torch.sum(target_probs * probs_to_logits(target_probs)) - torch.sum(
target_probs * probs_to_logits(nn_probs)
)


def KL(nn_state, target, space, bases=None, **kwargs):
r"""A function for calculating the total KL divergence.
.. math:: KL(P_{target} \vert P_{RBM}) = -\sum_{x \in \mathcal{H}} P_{target}(x)\log(\frac{P_{RBM}(x)}{P_{target}(x)})
:param nn_state: The neural network state (i.e. complex wavefunction or
positive wavefunction).
:type nn_state: qucumber.nn_states.WaveFunctionBase
:param target_psi: The true wavefunction of the system. Can be a dictionary
with each value being the wavefunction represented in a
different basis, and the key identifying the basis.
:type target_psi: torch.Tensor or dict(str, torch.Tensor)
:param nn_state: The neural network state.
:type nn_state: qucumber.nn_states.NeuralStateBase
:param target: The true state (wavefunction or density matrix) of the system.
Can be a dictionary with each value being the state
represented in a different basis, and the key identifying the basis.
:type target: torch.Tensor or dict(str, torch.Tensor)
:param space: The basis elements of the Hilbert space of the system :math:`\mathcal{H}`.
The ordering of the basis elements must match with the ordering of the
coefficients given in `target_psi`.
coefficients given in `target`.
:type space: torch.Tensor
:param bases: An array of unique bases. If given, the KL divergence will be
computed for each basis and the average will be returned.
@@ -205,111 +239,68 @@ def KL(nn_state, target_psi, space, bases=None, **kwargs):
:returns: The KL divergence.
:rtype: float
"""
psi_r = torch.zeros(
2, 1 << nn_state.num_visible, dtype=torch.double, device=nn_state.device
)
KL = 0.0

if isinstance(target_psi, dict):
target_psi = {k: v.to(nn_state.device) for k, v in target_psi.items()}
if isinstance(target, dict):
target = {k: v.to(nn_state.device) for k, v in target.items()}
if bases is None:
bases = list(target_psi.keys())
bases = list(target.keys())
else:
assert set(bases) == set(
target_psi.keys()
target.keys()
), "Given bases must match the keys of the target_psi dictionary."
else:
target_psi = target_psi.to(nn_state.device)
target = target.to(nn_state.device)

Z = nn_state.normalization(space)

Z = nn_state.compute_normalization(space)
if bases is None:
target_probs = cplx.absolute_value(target_psi) ** 2
target_probs = cplx.absolute_value(target) ** 2
nn_probs = nn_state.probability(space, Z)
KL += torch.sum(target_probs * probs_to_logits(target_probs))
KL -= torch.sum(target_probs * probs_to_logits(nn_probs))
else:

KL += _single_basis_KL(target_probs, nn_probs)

elif isinstance(nn_state, WaveFunctionBase):
unitary_dict = nn_state.unitary_dict

for basis in bases:
psi_r = rotate_psi(nn_state, basis, space, unitary_dict)
if isinstance(target_psi, dict):
target_psi_r = target_psi[basis]
if isinstance(target, dict):
target_psi_r = target[basis]
assert target_psi_r.dim() == 2, "target must be a complex vector!"
else:
assert target.dim() == 2, "target must be a complex vector!"

target_psi_r = rotate_psi(
nn_state, basis, space, unitary_dict, target_psi
nn_state, basis, space, unitary_dict, psi=target
)

probs_r = (cplx.absolute_value(psi_r) ** 2) / Z
nn_probs_r = (cplx.absolute_value(psi_r) ** 2) / Z
target_probs_r = cplx.absolute_value(target_psi_r) ** 2

KL += torch.sum(target_probs_r * probs_to_logits(target_probs_r))
KL -= torch.sum(target_probs_r * probs_to_logits(probs_r))
KL /= float(len(bases))

return KL.item()


def density_matrix_fidelity(nn_state, target, space, **kwargs):
r"""Calculate the fidelity of the reconstructed density matrix
given the exact target density matrix
:param nn_state: The neural network state (i.e. current density matrix)
:type nn_state: qucumber.nn_states.DensityMatrix
:param target: The true density matrix of the system
:type target: torch.Tensor
:param space: The basis elements of the visible space
:type space: torch.Tensor
:param \**kwargs: Extra keyword arguments that may be passed.
Will be ignored.
:returns: The fidelity
:rtype: float
"""
Z = nn_state.normalization(space)
rho = nn_state.rho(space, space) / Z
arg_real = cplx.real(rho).numpy()
arg_imag = cplx.imag(rho).numpy()

rho_rbm_ = arg_real + 1j * arg_imag

arg_real = cplx.real(target).numpy()
arg_imag = cplx.imag(target).numpy()

target_ = arg_real + 1j * arg_imag

sqrt_rho_rbm = sqrtm(rho_rbm_)
KL += _single_basis_KL(target_probs_r, nn_probs_r)

arg = sqrtm(np.matmul(sqrt_rho_rbm, np.matmul(target_, sqrt_rho_rbm)))
return np.trace(arg).real


def density_matrix_KL(nn_state, target, bases, space):
"""Computes the KL divergence between the current and target density matrix
:param target: The target density matrix
:type target: torch.Tensor
:param bases: The bases in which measurement is made
:type bases: numpy.ndarray
:param space: The space of the visible states
:type space: torch.Tensor
:returns: The KL divergence
:rtype: float
"""
Z = nn_state.normalization(space)
unitary_dict = nn_state.unitary_dict
rho_r_diag = torch.zeros(2 ** nn_state.num_visible, dtype=torch.double)
target_rho_r_diag = torch.zeros_like(rho_r_diag)
KL /= float(len(bases))
else:
unitary_dict = nn_state.unitary_dict

KL = 0.0
for basis in bases:
rho_r = rotate_rho(nn_state, basis, space, unitary_dict) / Z
if isinstance(target, dict):
target_rho_r = target[basis]
assert target_rho_r.dim() == 3, "target must be a complex matrix!"
else:
assert target.dim() == 3, "target must be a complex matrix!"

for basis in bases:
rho_r = rotate_rho(nn_state, basis, space, unitary_dict) / Z
target_rho_r = rotate_rho(nn_state, basis, space, unitary_dict, rho=target)
target_rho_r = rotate_rho(
nn_state, basis, space, unitary_dict, rho=target
)

rho_r_diag = torch.diagonal(cplx.real(rho_r))
target_rho_r_diag = torch.diagonal(cplx.real(target_rho_r))
nn_probs_r = torch.diagonal(cplx.real(rho_r))
target_probs_r = torch.diagonal(cplx.real(target_rho_r))

KL += torch.sum(target_rho_r_diag * probs_to_logits(target_rho_r_diag))
KL -= torch.sum(target_rho_r_diag * probs_to_logits(rho_r_diag))
KL += _single_basis_KL(target_probs_r, nn_probs_r)

KL /= float(len(bases))
KL /= float(len(bases))

return KL.item()
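
For context, a hedged usage sketch of the merged metrics (illustration only, not code from the commit; names like PositiveWaveFunction and generate_hilbert_space are taken from qucumber's public API as I understand it and should be checked against the installed version). Per the asserts in the diff, a wavefunction target has shape (2, 2**n) and a density-matrix target has shape (2, 2**n, 2**n) in qucumber's stacked real/imaginary format:

import torch
import qucumber.utils.training_statistics as ts
from qucumber.nn_states import PositiveWaveFunction

nn_state = PositiveWaveFunction(num_visible=3, num_hidden=3, gpu=False)
space = nn_state.generate_hilbert_space()

# Made-up target state: real part in index 0, imaginary part in index 1.
target = torch.zeros(2, 2 ** 3, dtype=torch.double)
target[0] = torch.rand(2 ** 3, dtype=torch.double)
target[0] /= target[0].norm()

print(ts.fidelity(nn_state, target, space))
print(ts.KL(nn_state, target, space))
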
8 changes: 4 additions & 4 deletions tests/grads_utils.py
@@ -252,8 +252,8 @@ def transform_bases(self, bases_data):

# return torch.tensor(np.array(num_gradNLL), dtype=torch.double).to(param)

def compute_numerical_KL(self, target, bases, space):
return ts.density_matrix_KL(self.nn_state, target, bases, space)
def compute_numerical_KL(self, target, space, bases):
return ts.KL(self.nn_state, target, space, bases=bases)

def algorithmic_gradKL(self, data_samples, data_bases, space, **kwargs):
return self.nn_state.compute_exact_gradients(data_samples, space, data_bases)
@@ -265,10 +265,10 @@ def numeric_gradKL(self, param, target, space, bases, eps, **kwargs):

for i in range(len(param)):
param[i] += eps
KL_p = self.compute_numerical_KL(target, bases, space)
KL_p = self.compute_numerical_KL(target, space, bases)

param[i] -= 2 * eps
KL_m = self.compute_numerical_KL(target, bases, space)
KL_m = self.compute_numerical_KL(target, space, bases)

param[i] += eps
num_gradKL.append((KL_p - KL_m) / (2 * eps))
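
The numeric_gradKL loop above is a standard central-difference gradient check. A self-contained sketch of the same pattern, with an invented helper name (numeric_grad) and a toy quadratic loss standing in for ts.KL:

import torch

def numeric_grad(loss_fn, params, eps=1e-6):
    # Central finite differences of a scalar loss w.r.t. a 1-D parameter tensor.
    grads = []
    for i in range(len(params)):
        params[i] += eps
        loss_p = loss_fn()

        params[i] -= 2 * eps
        loss_m = loss_fn()

        params[i] += eps  # restore the original value
        grads.append((loss_p - loss_m) / (2 * eps))
    return torch.tensor(grads, dtype=torch.double)

# Toy check: for a quadratic loss the exact gradient is 2 * params.
params = torch.tensor([0.3, -1.2, 0.7], dtype=torch.double)
print(numeric_grad(lambda: (params ** 2).sum().item(), params))
# ~ tensor([ 0.6000, -2.4000,  1.4000], dtype=torch.float64)
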
4 changes: 2 additions & 2 deletions tests/test_training.py
@@ -250,7 +250,7 @@ def test_trainingpositive(request):
MetricEvaluator(
log_every,
{"Fidelity": ts.fidelity, "KL": ts.KL},
target_psi=target_psi,
target=target_psi,
space=space,
verbose=True,
)
@@ -364,7 +364,7 @@ def test_trainingcomplex(request, vectorized):
MetricEvaluator(
log_every,
{"Fidelity": ts.fidelity, "KL": ts.KL},
target_psi=target_psi,
target=target_psi,
bases=bases,
space=space,
verbose=True,
