In [1]:
### IMPORTS ###

# Quantum libraries:
import pennylane as qml
#from pennylane 
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit
from functools import partial

# Plotting
from matplotlib import pyplot as plt
import plotly

# Other
import sys, os
import time
import copy
from tqdm.notebook import tqdm # Pretty progress bars
import joblib # Writing and loading
from noisyopt import minimizeSPSA
import optuna # Automatic tuning tool
import multiprocessing

import warnings
warnings.filterwarnings("ignore", message="For Hamiltonians, the eigenvalues will be computed numerically. This may be computationally intensive for a large number of wires.Consider using a sparse representation of the Hamiltonian with qml.SparseHamiltonian.")
##############


In [2]:
# Global flag to set a specific platform, must be used at startup.
# The 'jax_platform_name' option is set to 'cpu' here, ahead of every JAX
# computation in the cells below.
jax.config.update('jax_platform_name', 'cpu')

In [3]:
# My functions:
# '../' is put at the front of the module search path before importing the two
# project modules (presumably they live in the parent directory -- TODO confirm).
sys.path.insert(0, '../')
import vqe_functions as vqe
import qcnn_functions as qcnn


In [4]:
## Load data:
# NOTE: joblib.load unpickles the file, so only load files from a trusted source.
data = joblib.load('../vqe_states_job/vqe_params_0noise.job')
n_samples = len(data)

print('Size of Data Set: {0}'.format(n_samples))

# 80/20 train/test split.
# A dedicated, seeded generator makes the split identical on every
# "Restart Kernel -> Run All" and leaves the global NumPy random state untouched.
SPLIT_SEED = 0
split_rng = np.random.default_rng(SPLIT_SEED)
train_index = np.sort(split_rng.choice(np.arange(n_samples), size=int(0.8*n_samples), replace=False))
# Every index that was not drawn for training goes to the test set (ascending order).
test_index = np.setdiff1d(np.arange(n_samples), train_index)

# Each entry of `data` is indexed as [0] -> input (X) and [1] -> label (Y).
X_train = jnp.array([data[i][0] for i in train_index])
Y_train = jnp.array([data[i][1] for i in train_index])
X_test  = jnp.array([data[i][0] for i in test_index])
Y_test  = jnp.array([data[i][1] for i in test_index])

print('______________________________')
print('Size of Training Set: {0}'.format(np.shape(X_train)[0]))
print('Size of Test Set    : {0}'.format(np.shape(X_test)[0]))


Size of Data Set: 100
______________________________
Size of Training Set: 80
Size of Test Set    : 20


In [5]:
# Number of spins of the system; also used as the number of wires of the device
N = 4
# NOTE(review): J does not enter any computation in the visible cells;
# presumably a coupling constant of the model -- TODO confirm
J = 1

# Short aliases for the circuit-building functions of the two project modules
vqe_circuit_fun = vqe.vqe_circuit
qcnn_circuit_fun = qcnn.qcnn_circuit

In [6]:
# PennyLane device with N wires (one per spin); it is handed to the functions
# below, which build their QNodes on it.
device = qml.device("default.qubit.jax", wires = N, shots = None)

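* 1. Split the data into test and training sets with a 20-80 ratio. (This was not done randomly: the split was done so that the test points are as sparse (and as different) as possible. NOTE: the data-loading cell above, however, draws the training indices with a random `choice` — confirm which procedure applies.)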
* 2. We choose different sizes of the training set [npoints_list]
* 3. For npoints in npoints_list:
    * Draw randomly a subset of npoints from the training set
    * Train the QCNN just with those samples
    * Get the Loss and Accuracy on the Test Set
    * Get the mean and standard deviation of the test loss and accuracy over the QCNNs trained with npoints samples

In [7]:
# Training function
def train_jax(epochs, lr, r_shift, N, device, vqe_circuit_fun, qcnn_circuit_fun,
              X_train, Y_train, X_test = [], Y_test = [], plot = True, info = True, batch_size = 32):
    '''
    Train the QCNN parameters with SPSA (noisyopt.minimizeSPSA), minimizing the
    summed cross-entropy between the circuit output and the training labels.
    
    Parameters
    ----------
    epochs : int
        Number of optimizer iterations (passed as `niter`).
    lr : float
        'a' coefficient of the optimizer.
    r_shift : float
        'c' coefficient of the optimizer.
    N : int
        Number of spins of the system; the probabilities are read on wire N - 1.
    device
        PennyLane device on which the QNode is built.
    vqe_circuit_fun, qcnn_circuit_fun : callable
        Circuit builders. `qcnn_circuit_fun` is also called once outside of the
        QNode and its return value is used as the number of QCNN parameters.
    X_train, Y_train : array-like
        Training inputs (one VQE parameter vector per sample) and labels.
        Labels are compared against 0 and 1.
    X_test, Y_test : array-like, optional
        Test inputs and labels. When non-empty, test loss/accuracy are recorded
        at every 10th callback. The defaults are never mutated.
    plot : bool
        If True, plot the training loss and the accuracy histories.
    info : bool
        If True, print the hyper-parameters and show a progress bar.
    batch_size : int
        Number of training samples drawn (without replacement) for each cost
        evaluation; 0 means the whole training set is used.
    
    Returns
    -------
    loss_history : list
        Training cross-entropy recorded at every callback.
    accuracy_history : list
        Training accuracy (in %) recorded at every callback.
    params
        Final parameters returned by the optimizer (`res.x`).
        
    Note: the test-set histories are only used for the accuracy plot, they are
    not returned.
    '''
    X_train, Y_train = jnp.array(X_train), jnp.array(Y_train)
    X_test, Y_test = jnp.array(X_test), jnp.array(Y_test)
    
    @qml.qnode(device, interface="jax", diff_method=None)
    def qcnn_circuit_prob(params_vqe, params, N):
        qcnn_circuit_fun(params_vqe, vqe_circuit_fun, params, N, 0, 0, 0, 0)

        return qml.probs(wires = N - 1)
    
    def predict(params, X):
        # Map the circuit over the samples of X, all sharing the same QCNN params.
        # (The lambda argument is not called `vqe` to avoid shadowing the module alias.)
        return jax.vmap(lambda params_vqe: qcnn_circuit_prob(params_vqe, params, N))(X)
    
    def cross_entropy_sum(predictions, labels):
        # Summed (not averaged) cross-entropy; column 1 of `predictions` is used
        # as the probability assigned to label 1.
        ones  = np.where(np.equal(labels, 1))
        zeros = np.where(np.equal(labels, 0))
        return - np.sum( np.log(predictions[ones, 1]) ) - np.sum( np.log(1 - predictions[zeros, 1]) )
    
    def accuracy_percent(predictions, labels):
        # Percentage of samples whose most probable outcome equals the label
        return 100*np.sum(np.argmax(predictions, axis=1) == labels)/len(labels)
    
    if info:
        print('+-- PARAMETERS ---+')
        print('a factor   = {0} (\'a\' coefficient of the optimizer)'.format(lr) )
        print('r_shift    = {0} (c coefficient of the optimizer)'.format(r_shift) )
        print('epochs     = {0} (# epochs for learning)'.format(epochs) )
        print('N          = {0} (Number of spins of the system)'.format(N) )
        print('batch_size = {0} (batch size of the training process)'.format(batch_size) )
    
    # Initialize parameters: called outside of a QNode, the circuit builder
    # returns the number of parameters it needs
    n_params = qcnn_circuit_fun([0]*1000, vqe_circuit_fun, [0]*1000, N)
    params = [np.pi/4]*n_params
    
    loss_history = []
    accuracy_history = []
    loss_history_test = []
    accuracy_history_test = []
    
    # Cost function to minimize, returning the cross-entropy of a (mini-batch of
    # the) training set. The optimizer is run with paired=True and supplies `seed`;
    # seeding here makes calls with the same seed draw the same mini-batch.
    def update(params, seed = 0):
        np.random.seed(seed=seed)
        if batch_size == 0:
            sub_train_idx = np.arange(len(X_train))
        else:
            sub_train_idx = np.random.choice(np.arange(len(X_train)), batch_size, replace = False)
            
        X_train_sub = jnp.array(X_train[sub_train_idx])
        Y_train_sub = jnp.array(Y_train[sub_train_idx])
        
        return cross_entropy_sum(predict(params, X_train_sub), Y_train_sub)
    
    # Called by the optimizer with the current parameters: records loss and accuracy
    # on the full training set and, every 10 calls, on the test set
    def callback(params):
        predictions = predict(params, X_train)
        
        accuracy_history.append( accuracy_percent(predictions, Y_train) )
        loss_history.append( cross_entropy_sum(predictions, Y_train) )

        if len(Y_test) > 0 and len(accuracy_history)%10 == 0:
            predictions_test = predict(params, X_test)
            
            accuracy_history_test.append( accuracy_percent(predictions_test, Y_test) )
            loss_history_test.append( cross_entropy_sum(predictions_test, Y_test) )
        
        if info:
            pbar.update(1)
            pbar.set_description('Cost: {0} | Accuracy: {1}'.format(np.round(loss_history[-1],5), np.round(accuracy_history[-1],2) )  )
    
    # `tqdm` is the progress-bar class itself (imported with
    # `from tqdm.notebook import tqdm`), so it is called directly
    pbar = tqdm(total = epochs, position=0, leave=True) if info else None
    
    try:
        res = minimizeSPSA(update,
                           x0=params,
                           niter=epochs,
                           paired=True,
                           c=r_shift,
                           a=lr,
                           callback = callback)
    finally:
        # Always release the progress bar, even if the optimization is interrupted
        if pbar is not None:
            pbar.close()
    
    # Update final parameters
    params = res.x
    
    if plot:
        plt.figure(figsize=(15,5))
        plt.plot(np.arange(len(loss_history)), np.asarray(loss_history), label = 'Training Loss')
        plt.axhline(y=0, color='r', linestyle='--')
        plt.title('Loss history')
        # NOTE(review): the plotted values are summed cross-entropies although
        # the axis label says 'Average'
        plt.ylabel('Average Cross entropy')
        plt.xlabel('Epoch')
        plt.grid(True)
        plt.legend()
        
        plt.figure(figsize=(15,4))
        plt.plot(np.arange(len(accuracy_history)), accuracy_history, color='orange', label = 'Training Accuracy')
        if len(X_test) > 0:
            # Test accuracy is sampled every 10 callbacks, hence the x-axis stretch
            plt.plot(np.arange(len(accuracy_history_test))*10, accuracy_history_test, color='violet', label = 'Test Accuracy')
        plt.axhline(y=100, color='r', linestyle='--')
        plt.title('Accuracy')
        plt.ylabel('%')
        plt.xlabel('Epoch')
        plt.grid(True)
        plt.legend()
        
    return loss_history, accuracy_history, params

In [8]:
def reduced_npoints_accuracies(X_train, Y_train, X_test, Y_test, npoints_list, n_iters, epochs, device, vqe_circuit, qcnn_circuit,
                               testdata_ratio = 0.2, plot = True, lr = 0.05, r_shift = 0.0008):
    '''
    From the same VQE parameters we want to train many times a QCNN with different samples
    in order to find what is the average accuracy n-training points can reach.
    
    For every `npoints` in `npoints_list`, `n_iters` QCNNs are trained in parallel
    (multiprocessing.Pool), each on a random subset of `npoints` training samples,
    and evaluated on the whole test set.
    
    Parameters
    ----------
    X_train, Y_train : arrays
        Training inputs and labels the random subsets are drawn from.
    X_test, Y_test : arrays
        Test inputs and labels every trained QCNN is evaluated on.
    npoints_list : list of int
        Training-subset sizes to investigate (each must not exceed len(Y_train),
        the subset is drawn without replacement).
    n_iters : int
        Number of independent trainings per subset size.
    epochs : int
        Number of optimizer iterations of each training.
    device
        PennyLane device used for training and evaluation.
    vqe_circuit, qcnn_circuit : callable
        Circuit builders, used both for training and for the test-set evaluation.
    testdata_ratio : float
        Not used by this function; kept for backward compatibility.
    plot : bool
        If True, plot mean test loss / accuracy with 3-standard-deviation error bars.
    lr, r_shift : float
        'a' and 'c' coefficients of the SPSA optimizer, forwarded to `train_jax`.
    
    Returns
    -------
    loss_means, loss_devs, acc_means, acc_devs : lists
        Mean and standard deviation of the test cross-entropy and of the test
        accuracy (in %), one entry per element of `npoints_list`.
    
    Note: the number of spins `N` is read from module scope and `train_jax` must
    be defined before this function is called.
    '''
    # We need to declare this wrapper as global otherwise multiprocessing
    # does not work
    global wrapped_update
    
    @qml.qnode(device, interface="jax", diff_method=None)
    def qcnn_circuit_prob(params_vqe, params, N, vqe_circuit, qcnn_circuit):
        qcnn_circuit(params_vqe, vqe_circuit, params, N)
    
        return qml.probs(wires = N - 1)
    
    loss_means = []
    loss_devs  = []
    
    acc_means = []
    acc_devs  = []
    
    pbar = tqdm(total = len(npoints_list))
    
    try:
        for npoints in npoints_list:
            # Function for training a QCNN with multiprocessing
            def wrapped_update(idx):
                # Train a QCNN and save loss/accuracy for data-analysis
                # (`idx` only enumerates the repetitions, it is not used)
                
                # Force randomicity, multiprocessing tends to give the same random vectors
                # to each parallelized process
                np.random.seed(int.from_bytes(os.urandom(4), byteorder='little'))
                
                # Choose a random subset of the training set of size npoints
                train_idx_it = np.random.choice(np.arange(len(Y_train)), npoints, replace=False)
                X_train_it = X_train[train_idx_it]
                Y_train_it = Y_train[train_idx_it]
                
                # Train a QCNN with the random subset of the training set and get the parameters.
                # The circuit builders received as arguments are used here, so training and
                # evaluation always run on the same circuits.
                _, _, params = train_jax(epochs, lr, r_shift, N, device, vqe_circuit, qcnn_circuit,
                                         X_train_it, Y_train_it, plot = False, info = False, batch_size = 0)
                
                # Compute Accuracy and Loss on the test-set
                wrapper_circuit = lambda params_vqe: qcnn_circuit_prob(params_vqe, params, N, vqe_circuit, qcnn_circuit)
                vcircuit = jax.vmap(wrapper_circuit)
                predictions = vcircuit(X_test)
                    
                cross_entropy = - np.sum( np.log(predictions[np.where(np.equal(Y_test,1) ),1] )  ) - np.sum(np.log( 1 - predictions[np.where(np.equal(Y_test,0) ),1] ) )
                accuracy = 100*np.sum(np.argmax(predictions, axis=1) == Y_test)/len(Y_test)
                
                return cross_entropy, accuracy
            
            # One fresh pool per subset size; the context manager shuts it down
            with multiprocessing.Pool() as pool:
                rdata = pool.map(wrapped_update, np.arange(n_iters))
            
            # Column 0: test cross-entropy, column 1: test accuracy
            rdata = np.array(rdata)
            
            loss_means.append(np.mean(rdata[:,0]))
            loss_devs.append(np.std(rdata[:,0]))
            acc_means.append(np.mean(rdata[:,1]))
            acc_devs.append(np.std(rdata[:,1]))
            
            pbar.update(1)
    finally:
        # Release the progress bar even if the run is interrupted
        pbar.close()
        
    if plot:
        fig, ax = plt.subplots(2, 1, figsize=(10,10))
        
        # The subset sizes are placed at equally spaced positions and used as tick labels
        xs = np.arange(len(npoints_list))
            
        ax[0].plot(xs, loss_means, color='indigo', lw = 2, alpha = 0.3)
        ax[0].errorbar(xs, loss_means, yerr=3*np.array(loss_devs), fmt='o', color='indigo',
                       ecolor='blueviolet', elinewidth=3, capsize=0)
        ax[0].set_xticks(xs)
        ax[0].set_xticklabels(npoints_list)
        ax[0].grid(True)
        ax[0].set_title('Losses on Test Set')
        ax[0].set_xlabel('# points in Training Set')
        ax[0].set_ylabel('Cross-entropy')

        ax[1].plot(xs, acc_means, color='red', ms = 7, alpha = 0.3)
        ax[1].errorbar(xs, acc_means, yerr=3*np.array(acc_devs),fmt='o', color='red',
                       ecolor='red', elinewidth=3, capsize=0)
        ax[1].set_xticks(xs)
        ax[1].set_xticklabels(npoints_list)
        ax[1].grid(True)
        ax[1].set_title('Accuracies on Test Set')
        ax[1].set_xlabel('# points in Training Set')
        ax[1].set_ylabel('(%)')
        
    return loss_means, loss_devs, acc_means, acc_devs

In [9]:
# For every training-subset size: 100 independent QCNN trainings of 150 epochs each,
# evaluated on the test set
loss_means, loss_devs, acc_means, acc_devs = reduced_npoints_accuracies(
    X_train, Y_train, X_test, Y_test,
    npoints_list = [1, 10, 25, 50, 60, 70, 80],
    n_iters = 100,
    epochs = 150,
    device = device,
    vqe_circuit = vqe.vqe_circuit,
    qcnn_circuit = qcnn.qcnn_circuit,
    testdata_ratio = 0.2,
    plot = True,
)

  0%|          | 0/7 [00:00<?, ?it/s]

Process ForkPoolWorker-27:
Process ForkPoolWorker-22:
Process ForkPoolWorker-24:
Process ForkPoolWorker-23:
Process ForkPoolWorker-32:
Process ForkPoolWorker-19:
Process ForkPoolWorker-25:
Process ForkPoolWorker-26:
Process ForkPoolWorker-28:
Process ForkPoolWorker-21:
Process ForkPoolWorker-31:


KeyboardInterrupt: 

In [None]:
# Persist the (means, standard deviations) pairs of the test loss and of the test accuracy
joblib.dump(value=(loss_means, loss_devs), filename='../../data/small_sets/loss.job')
joblib.dump(value=(acc_means, acc_devs), filename='../../data/small_sets/acc.job')