In [None]:
# Packages
import os
# This must run before jax is imported below: XLA reads this variable when the GPU backend
# initializes (see the JAX "GPU memory allocation" docs), so setting it later has no effect.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION']='.99' # Provide as much GPU memory as possible

import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax import random

# Import functions from the project-local modules
import QM_Spin_Systems as QSS  # quantum-state helpers (e.g. correlators of a density matrix)
import LHV_Training_Setup as LHV  # measurement rules, distance measures and measurement samplers
import Training as TRG  # optimization loop (autoLHV) and evaluation (test_loss)
import Plotting as PLT  # sphere-plotting helpers

pi = jnp.pi  # short alias used throughout the notebook

In [None]:
# Optimize LHV model for single Werner state (should run in a few seconds on a single GPU)

N_spins = 2 # Number of spins, 2 for 2-spin Werner states

# Measurement rules
D = 5
PLHV, hidden_dim = LHV.PLHV_sh(D) # Alternatives: PLHV_Bell = PLHV_sh(1), PLHV_sh3 = PLHV_sh(3), PLHV_sh5 = PLHV_sh(5),
                                  #               PLHV_sh5_old (different normalization), PLHV_spherical_harmonics_planar (PLHV_sh5 w/ z=0)
                                  #               PLHV_polynomial(D): Odd polynomial expansion up to degree D
PQM = LHV.PQM_Werner # For Werner states, state=visibility
                     # Alternatives: PQM (state given by its correlation matrix),
                     #               PQM_XYZ (states with symmetries of "XYZ Hamiltonian" represented by correlators (xx, yy, zz))

# State
visibility = .5

# Setup keys for random number generator (fixed seed, so the run is reproducible)
key = random.PRNGKey(0)
key, key_init, key_train = random.split(key, 3)

# Hyper parameters
N_hidden = 2**14 # Hidden state cloud size
N_measures = 2**8 # Batch size (number of tuples of measurement directions sampled per gradient descent step)
N_measures_test = 2**14 # Batch size for final evaluation of the optimized cloud (also used for the L1 check below)
N_steps = 10**4 # Number of gradient descent steps (main phase)
N_steps_ft = 10**3 # Number of gradient descent steps with 10x smaller learning rate for fine tuning at the end of training
learning_rate = N_hidden * 2e-5 # The learning rate should be proportional to the hidden state cloud size
hyper_params = [N_measures, N_measures_test, N_steps, N_steps_ft, learning_rate]

functions = [PLHV, PQM, LHV.KL, LHV.sample3Dprojective] # Use the KL divergence as a distance measure and sample from all projective measurements
# Alternatives for distance measure: L2 (squared difference), L1 (difference, not suitable for training)
# Alternatives for sampling method: sample2Dprojective (only sample in xy plane), in general needs to be compatible with PQM

# Init and optimize the hidden state cloud
init = random.normal(key_init, (N_hidden, N_spins, hidden_dim)) # Independent gaussian initialization of all hidden state components in the cloud
# Obtain optimized hidden state cloud, the progression of the loss during training and the final loss
hidden_states, loss_values, test_loss = TRG.autoLHV(key_train, init, visibility, *hyper_params, *functions)

print("Final loss: {:.2e}".format(test_loss))

# Test mean difference of measurement probabilities ( = L1 loss) on a fresh key.
# Reuses N_measures_test from the hyper parameters above instead of redefining it, so the
# two evaluations cannot silently drift apart when the batch size is changed.
key, key_test = random.split(key)
test_functions = [PLHV, PQM, LHV.L1, LHV.sample3Dprojective]
test_loss_L1 = TRG.test_loss(key_test, hidden_states, visibility, N_measures_test, *test_functions)
print("Mean deviation: {:.2e}".format(test_loss_L1))

In [None]:
# Plot the loss progression of the Werner-state optimization on log-log axes
dpi = 150
size = 15
stpsize = 10 # Only plot the loss every stpsize steps

# Step index of every recorded loss value. Deriving it from len(loss_values), rather than from
# N_steps, keeps x and y the same length even if loss_values also contains the N_steps_ft
# fine-tuning steps (presumably appended by TRG.autoLHV; TODO confirm).
steps = np.arange(len(loss_values))

fig, ax = plt.subplots(figsize=(6, 3), dpi=dpi)
ax.plot(steps[stpsize-1::stpsize], loss_values[stpsize-1::stpsize], lw=.5, color="black")
ax.set_ylabel(r"Loss", size=size)
ax.set_xlabel(r"Steps", size=size)
ax.grid(alpha=0.5, zorder=-1)
ax.set_xscale("log")
ax.set_xlim((stpsize-1, 1.5*N_steps))
ax.set_yscale("log")
plt.show()

In [None]:
# Optimize LHV model for an arbitrary state defined by its correlations

N_spins = 1

# Measurement rules
PLHV, hidden_dim = LHV.PLHV_Bell() # = LHV.PLHV_sh(1)
PQM = LHV.PQM # requires a state defined by its correlation matrix

# State: a pure state |psi> (here spin "up") with density matrix rho = |psi><psi|,
# i.e. rho_ij = psi_i * conj(psi_j). The conjugate must be on the second factor; conjugating
# the first one gives conj(rho) = rho^T, which differs from rho for complex states
# (e.g. flips the sign of y components). `up` can be replaced by any normalized state vector.
up = jnp.array([1., 0.])
rho = jnp.outer(up, jnp.conjugate(up)) # density matrix for the state "up"
corrs = QSS.correlators(rho, N_spins) # correlation matrix

# Same fixed seed as the Werner example, so this cell is reproducible on its own
key = random.PRNGKey(0)
key, key_init, key_train = random.split(key, 3)

# Hyper parameters
N_hidden = 2**10 # Smaller cloud: less accurate, but individual hidden states are easier to see in the plot below
N_measures = 2**12
N_measures_test = 2**14
N_steps = 3*10**4
N_steps_ft = 10**3
learning_rate = N_hidden * 2e-5 # proportional to the hidden state cloud size
hyper_params = [N_measures, N_measures_test, N_steps, N_steps_ft, learning_rate]
functions = [PLHV, PQM, LHV.KL, LHV.sample3Dprojective] # KL divergence as the loss (LHV.L2 would be the squared difference), all projective measurements

# Init and optimize
init = random.normal(key_init, (N_hidden, N_spins, hidden_dim))
hidden_states, loss_values, test_loss = TRG.autoLHV(key_train, init, corrs, *hyper_params, *functions)

print("Final loss: {:.2e}".format(test_loss))

In [None]:
# Plotting including the hidden state distribution (first three components normalized to the surface of a sphere)
import matplotlib

dpi = 150
size = 15
stpsize = 10 # Only plot the loss every stpsize steps

fig = plt.figure(figsize=(5, 3*N_spins+3), dpi=dpi, constrained_layout=True)

# Top panel: loss of the main training phase (first N_steps recorded values)
ax00 = fig.add_subplot(N_spins+1, 1, 1)
ax00.plot(range(stpsize-1, N_steps, stpsize), loss_values[stpsize-1:N_steps:stpsize], lw=.5, color="black")
ax00.set_title(r"Final loss $= {:.2e}$".format(test_loss), size=size-5)
ax00.set_xlabel(r"Steps", size=size-5)
ax00.set_ylabel(r"Loss", size=size-5)
ax00.grid(alpha=0.5, zorder=-1)
ax00.set_xscale("log")
ax00.set_yscale("log")

# Longitude/latitude grid over the sphere and the corresponding unit vectors
N_grid = 100
phi_grid = jnp.linspace(-pi, pi, 2*N_grid)
theta_grid = jnp.linspace(-pi/2., pi/2., N_grid)
phi_grid, theta_grid = jnp.meshgrid(phi_grid, theta_grid)
normal_vectors = PLT.normal(phi_grid, theta_grid)
delta = 0.02 # width parameter passed to PLT.gaussian_blurr

# Each density below is rescaled to a maximum of 1, so a fixed [0, 1] color scale is used
cnorm = matplotlib.colors.Normalize(vmin=0., vmax=1., clip=False)

for j in range(N_spins):

    # First three hidden-state components of spin j, projected onto the unit sphere
    spin = hidden_states[:, j, :3]
    spin = spin / jnp.linalg.norm(spin, axis=-1, keepdims=True)

    phi_grid_scatter, theta_grid_scatter = PLT.S2_angles(spin)

    # Blurred density of the cloud on the grid, normalized to a maximum of 1
    density = PLT.gaussian_blurr(spin, normal_vectors, delta)
    density = density / jnp.max(density)

    ax = fig.add_subplot(N_spins+1, 1, j+2, projection='mollweide')
    # Pass the scalar density with cmap/norm instead of precomputed RGBA colors: the image is
    # the same, but the colorbar now shows the true [0, 1] scale instead of a norm
    # autoscaled from RGBA channel values (cmap is ignored for RGBA input).
    mesh = ax.pcolormesh(phi_grid, theta_grid, np.asarray(density), cmap="plasma_r", norm=cnorm)
    ax.scatter(phi_grid_scatter, theta_grid_scatter, s=1, color="black", marker="o", alpha=1)
    # Grid lines at these positions, with tick labels hidden
    ax.set_xticks([-pi, -3.*pi/4., -pi/2, -pi/4., 0., pi/4., pi/2., 3.*pi/4., pi])
    ax.set_yticks([-0.99*pi/2., -pi/4., 0., pi/4., 0.99*pi/2.])
    ax.set_xticklabels([''] * 9)
    ax.set_yticklabels([''] * 5)
    ax.set_longitude_grid_ends(90)
    ax.grid(alpha=0.7)
    ax.set_title(f"Spin {j+1} Marginal")

fig.colorbar(mesh, ax=ax, location="bottom", label="Normalized density")
plt.show()