In [1]:
### IMPORTS ###
# Quantum libraries:
import pennylane as qml
from pennylane import numpy as np
import jax
import jax.numpy as jnp
from jax import jit
from functools import partial

# Plotting
from matplotlib import pyplot as plt
import plotly

# Other
import os, sys
import time

import copy
import tqdm # Pretty progress bars
from IPython.display import Markdown, display # Better prints
import joblib # Writing and loading
from noisyopt import minimizeSPSA
import optuna # Automatic tuning tool

import multiprocessing

import warnings
warnings.filterwarnings("ignore", message="For Hamiltonians, the eigenvalues will be computed numerically. This may be computationally intensive for a large number of wires.Consider using a sparse representation of the Hamiltonian with qml.SparseHamiltonian.")

##############

# My functions:
sys.path.insert(0, '../')
import vqe_functions as vqe
import qcnn_functions as qcnn


In [1]:
# Global flag to set a specific platform, must be used at startup.
jax.config.update('jax_platform_name', 'cpu')


NameError: name 'jax' is not defined

In [4]:
N = 4
for spin in range(N-1,0,-1):
    print(spin)

3
2
1


In [2]:
N = 4
vqe_circuit_fun = vqe.vqe_circuit
qcnn_circuit_fun= qcnn.qcnn_circuit

In [3]:
device_jax = qml.device("default.qubit.jax", wires = N, shots = None)
device_mix = qml.device("default.mixed", wires = N, shots = None)
device_def = qml.device("default.qubit", wires = N, shots = None)

In [4]:
## Load data:
data = joblib.load('../vqe_states_job/vqe_params_0noise.job')

print('Size of Data Set: {0}'.format(len(data)))

#train_index = np.concatenate(( np.random.choice(np.arange(25), 15, replace = False),np.random.choice(np.arange(75,100), 15, replace = False) ))
#test_index  = np.arange(25,75)
train_index = np.random.choice(np.arange(len(data)), int(0.8*len(data)), replace = False)
test_index = []

X_train, Y_train = [], []
X_test, Y_test   = [], []

for i in range(len(data)): 
    if i in train_index:
        X_train.append(data[i][0])
        Y_train.append(data[i][1])
    else:
        test_index.append(i)
        X_test.append(data[i][0])
        Y_test.append(data[i][1])
        
X_train, Y_train = jnp.array(X_train), jnp.array(Y_train)
#X_test, Y_test   = jnp.array(X_test), jnp.array(Y_test)

print('______________________________')
print('Size of Training Set: {0}'.format(np.shape(X_train)[0]))
print('Size of Test Set    : {0}'.format(np.shape(X_test)[0]))


Size of Data Set: 100
______________________________
Size of Training Set: 80
Size of Test Set    : 20


In [5]:
jaxgpu_start = time.time()
_ = qcnn.train_jax(100, 0.005, 0.0001, N, device_jax, vqe_circuit_fun, qcnn_circuit_fun, 
          X_train, Y_train, X_test, Y_test, plot = False, info = True, batch_size = 0)
jaxgpu_stop  = time.time()

+-- PARAMETERS ---+
a factor   = 0.005 ('a' coefficient of the optimizer)
r_shift    = 0.0001 (c coefficient of the optimizer)
epochs     = 100 (# epochs for learning)
N          = 4 (Number of spins of the system)
batch_size = 0 (batch size of the training process)


Cost: 40.47184 | Accuracy: 76.25: 100%|███████| 100/100 [01:46<00:00,  1.06s/it]


In [6]:
# Global flag to set a specific platform, must be used at startup.
jax.config.update('jax_platform_name', 'cpu')

jaxcpu_start = time.time()
_ = qcnn.train_jax(100, 0.005, 0.0001, N, device_jax, vqe_circuit_fun, qcnn_circuit_fun, 
          X_train, Y_train, X_test, Y_test, plot = False, info = True, batch_size = 0)
jaxcpu_stop  = time.time()

+-- PARAMETERS ---+
a factor   = 0.005 ('a' coefficient of the optimizer)
r_shift    = 0.0001 (c coefficient of the optimizer)
epochs     = 100 (# epochs for learning)
N          = 4 (Number of spins of the system)
batch_size = 0 (batch size of the training process)


Cost: 41.69527 | Accuracy: 76.25: 100%|███████| 100/100 [01:38<00:00,  1.02it/s]


In [7]:
# Default.mixed (multiprocessed) 
mixed_start = time.time()
_ = qcnn.train(100, 0.005, 0.0001, N, device_mix, vqe_circuit_fun, qcnn_circuit_fun, 0, 0, 0, 0,
          X_train, Y_train, X_test, Y_test, plot = False, info = True, batch_size = 0)
mixed_stop = time.time()


+-- PARAMETERS ---+
a factor   = 0.005 ('a' coefficient of the optimizer)
r_shift    = 0.0001 (c coefficient of the optimizer)
epochs     = 100 (# epochs for learning)
N          = 4 (Number of spins of the system)
batch_size = 0 (batch size of the training process)


Cost: 41.69386 | Accuracy: 76.25: 100%|███████| 100/100 [01:51<00:00,  1.11s/it]


In [8]:
# Default.mixed (multiprocessed) 
def_start = time.time()
_ = qcnn.train(100, 0.005, 0.0001, N, device_def, vqe_circuit_fun, qcnn_circuit_fun, 0, 0, 0, 0,
          X_train, Y_train, X_test, Y_test, plot = False, info = True, batch_size = 0)
def_stop = time.time()


+-- PARAMETERS ---+
a factor   = 0.005 ('a' coefficient of the optimizer)
r_shift    = 0.0001 (c coefficient of the optimizer)
epochs     = 100 (# epochs for learning)
N          = 4 (Number of spins of the system)
batch_size = 0 (batch size of the training process)


Cost: 41.69386 | Accuracy: 76.25: 100%|███████| 100/100 [02:18<00:00,  1.39s/it]


In [9]:
print('RESULTS')
devs = ['Jax GPU', 'Jax CPU', 'Mixed  ', 'Default']
res = [jaxgpu_stop-jaxgpu_start, jaxcpu_stop-jaxcpu_start, mixed_stop-mixed_start, def_stop-def_start]
resdevs = zip(res, devs)
sorted_resdevs = sorted(resdevs)
for pos, (re, devs) in enumerate(sorted_devres):
    print(pos,':', dev, '   ', re, 's')

RESULTS
0 : Default     138.7549843788147 s
1 : Jax CPU     98.51372075080872 s
2 : Jax GPU     106.19379806518555 s
3 : Mixed       111.19150280952454 s
