In [None]:
# Packages
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax import random
import optax
from functools import reduce

# Functions
import QM_Spin_Systems as QSS
import LHV_Setup as LHV
import Training as TRG
import Plotting as PLT
import Qudits as QDT

SEED = 42 # seed of the initial key for the random number generator
key = random.PRNGKey(SEED)

In [None]:
# Werner-state example: optimize an LHV model for a single Werner state
# (runtime ~ 10 seconds on a single NVIDIA Quadro RTX 6000 GPU)

# ---- State ----
N_spins = 2       # Number of spins, 2 for 2-spin Werner states
visibility = 0.5  # The state is a Werner state with this visibility

# ---- Measurement rules ----
# LHV rule. Alternatives: PLHV_Bell = PLHV_sh(1), PLHV_sh3 = PLHV_sh(3), PLHV_sh5 = PLHV_sh(5),
#                         PLHV_sh5_old (different normalization), PLHV_spherical_harmonics_planar (PLHV_sh5 w/ z=0),
#                         PLHV_polynomial(D): Odd polynomial expansion up to degree D
D = 5
PLHV, hidden_dim = LHV.PLHV_sh(D)
# QM rule; for Werner states, state=visibility.
# Alternatives: PQM (state given by its correlation matrix),
#               PQM_XYZ (states with symmetries of "XYZ Hamiltonian" represented by correlators (xx, yy, zz))
PQM = LHV.PQM_Werner

# ---- Random keys: carry `key` forward, derive one key for the init and one for the training ----
key, key_init, key_train = random.split(key, 3)

# ---- Hyper parameters ----
N_hidden = 2**14         # Size of the hidden-variable cloud
N_measures = 2**8        # Training batch size (tuples of measurement directions sampled per gradient descent step)
N_measures_test = 2**14  # Batch size for the final evaluation of the optimized cloud
N_steps = 10**4          # Number of gradient descent steps
hyper_params = [N_measures, N_measures_test, N_steps]

# ---- Optimizer ----
# Adam with a piecewise constant learning rate: the last N_steps_ft steps run with a 10x smaller
# learning rate for fine tuning. The learning rate should be proportional to the hidden-variable cloud size.
N_steps_ft = 10**3
learning_rate = 2e-5 * N_hidden
schedule = optax.piecewise_constant_schedule(
    init_value=learning_rate,
    boundaries_and_scales={N_steps - N_steps_ft: 0.1},
)
optimizer = optax.adam(learning_rate=schedule)

# ---- Distance measure and sampling method ----
# KL divergence as distance measure, sampling from all projective measurements.
# Alternatives for distance measure: L2 (squared difference), L1 (difference, not suitable for training)
# Alternatives for sampling method: sample2Dprojective (only sample in xy plane), in general needs to be compatible with PQM
functions = [PLHV, PQM, LHV.KL, LHV.sample3Dprojective, optimizer]

# ---- Initialize and optimize the hidden-variable cloud ----
# All hidden-variable components in the cloud start as independent gaussian samples
init = random.normal(key_init, shape=(N_hidden, N_spins, hidden_dim))
# Returns the optimized cloud, the progression of the loss during training and the final loss
cloud, loss_values, test_loss = TRG.autoLHV(key_train, init, visibility, *hyper_params, *functions)

print("Final loss (KL): {:.2e}".format(test_loss))

# ---- Mean difference of measurement probabilities ( = L1 loss) on the optimized cloud ----
test_functions = [PLHV, PQM, LHV.L1, LHV.sample3Dprojective]
N_measures_test = 2**14
key, key_test = random.split(key)
test_loss_L1 = TRG.test_loss(key_test, cloud, visibility, N_measures_test, *test_functions)
print("Mean deviation (L1): {:.2e}".format(test_loss_L1))

In [None]:
# Loss progression of the training run (both axes logarithmic)
dpi = 150
size = 15
stpsize = 10 # Thin out the curve: only every stpsize-th step is drawn

steps_shown = range(stpsize-1, N_steps, stpsize)
loss_shown = loss_values[stpsize-1::stpsize]

fig_loss, ax_loss = plt.subplots(figsize=(6, 3), dpi=dpi)
ax_loss.plot(steps_shown, loss_shown, lw=.5, color="black")
ax_loss.set_ylabel(r"Loss", size=size)
ax_loss.set_xlabel(r"Steps", size=size)
ax_loss.grid(alpha=0.5, zorder=-1)
ax_loss.set_xscale("log")
ax_loss.set_xlim((stpsize-1, 1.5*N_steps))
#ax_loss.set_ylim((2e-8, 1e-7))
ax_loss.set_yscale("log")
plt.show()

In [None]:
# Optimize LHV model for an arbitrary qubit state defined by its correlations (runtime ~ 20 seconds on a single NVIDIA Quadro RTX 6000 GPU)

N_spins = 1 # Number of qubits

# Measurement rules
PLHV, hidden_dim = LHV.PLHV_Bell() # = LHV.PLHV_sh(1)
PQM = LHV.PQM # requires a state defined by its correlation matrix

# State: pure state |psi><psi| for psi = "up".
# jnp.outer does not conjugate either argument, so the conjugate has to sit on the second (bra) factor:
# rho[a, b] = psi[a] * conj(psi[b]). For the real vector `up` this equals the transposed form,
# but it stays correct when a state with complex amplitudes is substituted.
up = jnp.array([1., 0.])
rho = jnp.outer(up, jnp.conjugate(up)) # density matrix for the state "up"
corrs = QSS.correlators(rho, N_spins) # correlation matrix

# Carry `key` forward and derive one key for the initialization and one for the training
key, key_init, key_train = random.split(key, (3, ))

# Hyper parameters
N_hidden = 2**12 # Hidden-variable cloud size
N_measures = 2**10 # Batch size per gradient descent step
N_measures_test = 2**14 # Batch size for the final evaluation
N_steps = 1*10**5 # Number of gradient descent steps
hyper_params = [N_measures, N_measures_test, N_steps]

# Optimizer: Adam, learning rate reduced by a factor 0.1 for the last N_steps_ft steps
N_steps_ft = 1*10**4
learning_rate = N_hidden * 1e-5
schedule = optax.piecewise_constant_schedule(learning_rate, boundaries_and_scales={N_steps-N_steps_ft: 0.1})
optimizer = optax.adam(learning_rate=schedule)

# KL divergence as distance measure, sampling with LHV.sample3Dprojective
functions = [PLHV, PQM, LHV.KL, LHV.sample3Dprojective, optimizer]

# Init (independent gaussian samples) and optimize; the state is passed as its correlation matrix
init = random.normal(key_init, (N_hidden, N_spins, hidden_dim))
cloud, loss_values, test_loss = TRG.autoLHV(key_train, init, corrs, *hyper_params, *functions)

print("Final loss: {:.2e}".format(test_loss))

In [None]:
# Plotting including the hidden-variable distribution (first three components normalized to the surface of a sphere)
import matplotlib

dpi = 150
size = 15
stpsize = 10 # Only plot the loss every stpsize steps

plt.figure(figsize=(5, 3*(N_spins)+3), dpi=dpi, constrained_layout=True)

# Top panel: loss progression on log-log axes
ax00 = plt.subplot(N_spins+1, 1, 1)
ax00.plot(range(stpsize-1, N_steps, stpsize), loss_values[stpsize-1:N_steps:stpsize], lw=.5, color="black")
plt.xlabel(r"Steps", size=size-5)
plt.grid(alpha=0.5, zorder=-1)
plt.xscale("log")
plt.yscale("log")

# Angular grid (phi in [-pi, pi], theta in [-pi/2, pi/2]) on which the blurred density is evaluated
N_grid = 100
phi_grid = jnp.linspace(-jnp.pi, jnp.pi, 2*N_grid)
theta_grid = jnp.linspace(-jnp.pi/2., jnp.pi/2., N_grid)
phi_grid, theta_grid = jnp.meshgrid(phi_grid, theta_grid)
normal_vectors = PLT.normal(phi_grid, theta_grid)
delta = 0.02 # third argument of PLT.gaussian_blurr (presumably the blur width -- confirm in Plotting.py)

# One fixed color scale for all panels: each density is divided by its maximum below, so it lies in [0, 1]
minimum = 0. # jnp.min(density)
maximum = 1. # jnp.max(density)
cnorm = matplotlib.colors.Normalize(vmin = minimum, vmax = maximum, clip=False)

for j in range(N_spins):

    # First three hidden-variable components of spin j, projected onto the unit sphere
    spin = cloud[:, j, :3]
    spin /= jnp.linalg.norm(spin, axis=-1, keepdims=True)

    phi_grid_scatter, theta_grid_scatter = PLT.S2_angles(spin)

    # Blurred density on the grid, rescaled so that its maximum is 1
    density = PLT.gaussian_blurr(spin, normal_vectors, delta)
    density /= jnp.max(density)

    ax = plt.subplot(N_spins+1, 1, j+2, projection='mollweide')
    # Hand the scalar field plus cmap/norm to pcolormesh instead of pre-computed RGBA colors:
    # the mesh then carries the norm itself, so the colorbar created from it shows the [0, 1]
    # density scale rather than a range autoscaled from color-channel values.
    c = ax.pcolormesh(phi_grid, theta_grid, np.asarray(density), cmap="plasma_r", norm=cnorm)
    ax.scatter(phi_grid_scatter, theta_grid_scatter, s=1, color="black", marker="o", alpha=.3)
    ax.set_xticks([-jnp.pi, -3.*jnp.pi/4., -jnp.pi/2, -jnp.pi/4., 0., jnp.pi/4., jnp.pi/2., 3.*jnp.pi/4., jnp.pi])
    ax.set_yticks([-0.99*jnp.pi/2., -jnp.pi/4., 0., jnp.pi/4., 0.99*jnp.pi/2.])
    ax.set_xticklabels(['', '', '', '', '', '', '', '', '']) #['$-\pi$', '$-\pi/2$', '0', '$\pi/2$', '$\pi$']
    ax.set_yticklabels(['', '', '', '', '']) # ['$-\pi/2$', '$-\pi/4$', '0', '$\pi/4$', '$\pi/2$']
    ax.set_longitude_grid_ends(90)
    plt.grid(alpha=0.7)
    ax.set_title(f"Spin {j+1} Marginal")

# Single colorbar below the last panel; valid for every panel because they share cnorm
plt.colorbar(c, ax=ax, location="bottom")
plt.show()

In [None]:
# Optimize LHV model for a qutrit state defined by its density matrix (runtime ~ 8,5 minutes on a single NVIDIA Quadro RTX 6000 GPU)

# General Setup

measurement_type = "PVM"    # PVM / POVM measurements
d = 3                       # qudit dimension
D = 1                       # maximal monomial degree
symmetric = True            # Use permutation invariant measurement rule or not
ONB = True                  # Orthonormalize monomials or not
coeffs = None               # Coefficients for ONB from monomials (if previously saved)
N_MC = 10**5                # Number of Monte Carlo samples for L2 inner product for Gram-Schmidt orthonormalization

match measurement_type:
    case "PVM":
        delta = d # number of measurement outcomes
        n_vars = (d-1) * (d**2-1) # number of variables specifying a measurement
        sample = QDT.sample_PVMs # sample PVM measurements
        params_extractor = QDT.gell_mann_params_PVM # extract Gell-Mann parameters from measurement operators
        PQM = QDT.PQM_pvm # QM measurement rule
    case "POVM":
        delta = d**2
        n_vars = (d**2-1) * d**2
        sample = QDT.sample_POVMs
        params_extractor = QDT.gell_mann_params_POVM
        PQM = QDT.PQM_povm
    case _:
        print("Measurement type needs to be 'PVM' or 'POVM'")

if ONB and (coeffs is None):
    key, key_M = random.split(key)
    measurements = jnp.squeeze(sample(key_M, N_MC, 1, d)) # Monte Carlo Samples
else:
    measurements = None

# Measurement rule, hidden-variable dimension and matrix of coefficients if monomials are orthonormalized
PLHV, hidden_dim, coeffs = QDT.LHV_rule_constructor(delta, n_vars, D, params_extractor, symmetric=symmetric, ONB=ONB, coeffs=coeffs, samples=measurements, alpha=1e-4, beta=1e-3)


# Hyper parameters
N_hidden = 2**12   # hidden-variable cloud size
N_steps = int(1e5)
learning_rate = N_hidden * 3e-6
N_measures = 2**9
N_measures_test = 2**12
N_test_runs = 2**4

# Isotropic state
N_particles = 2 # 2-qutrit state
visibility = 0.1
i = complex(0., 1.)
basis = jnp.eye(d)
psi = sum([reduce(jnp.kron, N_particles*[basis[j]]) for j in range(d)]) / jnp.sqrt(d) 
rho = visibility*jnp.outer(psi, psi) + (1.-visibility)*jnp.eye(d**N_particles)/d**N_particles

# keys
key, key_init, key_train = random.split(key, (3,))

# loss function
deviation = QDT.distance_KL

N_steps1 = int(1e3) # linearly annihilate learning rate starting after N_step1 steps
schedule = optax.schedules.linear_schedule(learning_rate, 0., N_steps-N_steps1, transition_begin=N_steps1)
optimizer = optax.adam(learning_rate=schedule)
hyper_params = [d, N_measures, N_measures_test, N_steps] 
functions = [PLHV, PQM, deviation, sample, optimizer]

# Init hidden variables
init = random.normal(key_init, (N_hidden, N_particles, delta, hidden_dim))

print("Start optimizing LHV")
# optimize
cloud, loss_values = QDT.autoLHV(key_train, init, rho, *hyper_params, *functions)

print("Start evaluating LHV")
# test loss
test_losses = np.zeros(N_test_runs)
for jdx in range(N_test_runs):
    key, key_test = random.split(key)
    test_losses[jdx] = QDT.eval_test_loss(key_test, cloud, rho, d, N_measures_test, *(functions[:-1]))
test_loss = np.mean(test_losses)

print("final loss: {:.2e}".format(test_loss))

In [None]:
# Loss progression of the training run (both axes logarithmic)
dpi = 150
size = 15
stpsize = 10 # Thin out the curve: only every stpsize-th step is drawn

steps_shown = range(stpsize-1, N_steps, stpsize)
loss_shown = loss_values[stpsize-1::stpsize]

fig_loss, ax_loss = plt.subplots(figsize=(6, 3), dpi=dpi)
ax_loss.plot(steps_shown, loss_shown, lw=.5, color="black")
ax_loss.set_ylabel(r"Loss", size=size)
ax_loss.set_xlabel(r"Steps", size=size)
ax_loss.grid(alpha=0.5, zorder=-1)
ax_loss.set_xscale("log")
ax_loss.set_xlim((stpsize-1, 1.5*N_steps))
#ax_loss.set_ylim((2e-8, 1e-7))
ax_loss.set_yscale("log")
plt.show()