In [1]:
import pennylane as qml
import jax
from jax import numpy as jnp
import optax
import circuit_lib as cl
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
print(jax.devices())
jax.config.update('jax_enable_x64', True) 

[CudaDevice(id=0)]


In [2]:
df = pd.read_csv("apple_quality_balanced_dataset.csv", header = 0)

df.loc[df["Quality"] == "good", "Quality"] = 1
df.loc[df["Quality"] == "bad", "Quality"] = 0
df = df.drop(columns=["A_id"])

scaler = MinMaxScaler(feature_range=(0,1))
scaled_x = scaler.fit_transform(df.iloc[:,:7])

train = scaled_x[:850]
test = scaled_x[850:]

x_train = jnp.array(train)
y_train = jnp.array(df.iloc[:850,7].to_numpy(dtype= np.float32))

x_test = jnp.array(test)
y_test = jnp.array(df.iloc[850:,7].to_numpy(dtype= np.float32))

In [3]:
dev = qml.device("default.qubit", wires = 7)

ZZFeatureMap = cl.create_ZZFeatureMap(dev)

@qml.qnode(dev, interface="jax")
def EmbeddingCircuit(x):
    ZZFeatureMap(x)
    return qml.state()


@jax.jit
def embeddor(x):
    return EmbeddingCircuit(x)

embedding_map = jax.vmap(embeddor,[0])

x_train_embedding = embedding_map(x_train)
x_test_embedding = embedding_map(x_test)

In [4]:
num_reps = 6
entanglement = cl.SCA 

In [32]:
observable = qml.numpy.asarray([[0,0],
                        [0,1]])

observable = qml.numpy.array([
    [1, 0, 0, 0, 0, 0, 0, 0],
    [0, 1, 0, 0, 0, 0, 0, 0],
    [0, 0, 1, 0, 0, 0, 0, 0],
    [0, 0, 0, 1, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0]
])
TwoLocal = cl.create_TwoLocal(dev,num_reps,entanglement)
@qml.qnode(dev, interface = "jax")
def QuantumCircuit(params, embedding):
    qml.StatePrep(embedding, wires = [0,1,2,3,4,5,6], normalize= True)
    TwoLocal(params)
    for i in range(7):
        qml.Hadamard(i)
    return qml.expval(qml.Hermitian(observable, wires=[0,1,2]))


In [33]:
@jax.jit
def mse(params, data, target):
    return (target - QuantumCircuit(params, data))**2

mse_map = jax.vmap(mse, (None, 0, 0))

@jax.jit
def loss_fn(params, data, target):
    return jnp.mean(mse_map(params, data, target))

In [34]:
learning_rate_schedule = optax.schedules.join_schedules(schedules = [optax.constant_schedule((0.75)**i) for i in range(6)],
                                                                     boundaries = [100, 250, 500, 700, 900])
opt = optax.adam(1e20)
max_steps = 1000
@jax.jit
def optimiser(params, data, targets, print_training):
    param_history = jnp.zeros(((max_steps + 1),) + params.shape)  
    grad_history = jnp.zeros(((max_steps + 1),) + params.shape)  
    opt_state = opt.init(params)
    #Packages the arguments to be sent to the function update_step_jit 
    args = (params, opt_state, jnp.asarray(data), targets, print_training, param_history, grad_history)
    #Loops max_steps number of times
    (params, opt_state, _, _, _, param_history, grad_history) = jax.lax.fori_loop(0, max_steps+1, update_step_jit, args) 
    return params, param_history, grad_history

@jax.jit
def update_step_jit(i,args):
    # Unpacks the arguments
    params, opt_state, data, targets, print_training, param_history, grad_history = args
    param_history = param_history.at[i].set(params)
    # Gets the loss and the gradients to be applied to the parameters, by passing in the loss function and the parameters, to see how the parameters perform 
    loss_val, grads = jax.value_and_grad(loss_fn)(params, data, targets)
    grad_history = grad_history.at[i].set(grads)
    #Prints the loss every 25 steps if print_training is enable
    def print_fn():
        jax.debug.print("Step: {i}  Loss: {loss_val}", i=i, loss_val=loss_val)
    jax.lax.cond((jnp.mod(i, 50) == 0 ) & print_training, print_fn, lambda: None)
    #Applies the param updates and updates the optimiser states
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    #Returns the arguments to be resupplied in the next iteration
    return (params, opt_state, data, targets, print_training, param_history, grad_history)   


In [35]:
init_params = jnp.array(np.random.default_rng().random(size=(num_reps + 1, 7))
)
opt_params, params_hist, grad_hist = optimiser(init_params, x_train_embedding, y_train, True)
print(loss_fn(opt_params, x_test_embedding, y_test))



Step: 0  Loss: 0.24242673928422356
Step: 50  Loss: 0.24998667709954
Step: 100  Loss: 0.25232226035008376
Step: 150  Loss: 0.25170440099981073
Step: 200  Loss: 0.24368691560392197
Step: 250  Loss: 0.2562264973985301
Step: 300  Loss: 0.2514164339285035
Step: 350  Loss: 0.2547674236774299
Step: 400  Loss: 0.250229512463104
Step: 450  Loss: 0.2518204284793223
Step: 500  Loss: 0.25601424784704446
Step: 550  Loss: 0.2524158902377224
Step: 600  Loss: 0.24671911549767178
Step: 650  Loss: 0.24932116880618785
Step: 700  Loss: 0.24393344715776835
Step: 750  Loss: 0.24576811531342727
Step: 800  Loss: 0.26021793320784764
Step: 850  Loss: 0.26827953128418597
Step: 900  Loss: 0.24937669030962323
Step: 950  Loss: 0.25023331993601255
Step: 1000  Loss: 0.25962754204111194
0.25345422523992844


In [23]:
print(params_hist[0])
print(params_hist[500])
print(params_hist[0].shape)

[[0.83447189 0.73620322 0.94334843 0.69785405 0.42409798 0.3311938
  0.11164608]
 [0.40849868 0.34635488 0.12168918 0.61002939 0.37040859 0.43267123
  0.90275899]
 [0.32874892 0.78818826 0.16675669 0.85882916 0.68963156 0.49743242
  0.18986433]
 [0.9298198  0.80092225 0.30607966 0.03111455 0.5737384  0.31792827
  0.10503104]
 [0.67123105 0.58436243 0.45929824 0.85950728 0.33422571 0.05206273
  0.91733703]
 [0.28639464 0.78547651 0.97777961 0.78469062 0.39450294 0.54850348
  0.31256419]
 [0.04738927 0.1742703  0.76815272 0.5933384  0.12484139 0.7514961
  0.36935942]]
[[ 20.99953229   9.13853645  -0.84791714   8.31992197   6.11225803
    7.98665302  22.35077907]
 [  3.44577929   2.06278934  18.68668613  20.92378007 -14.31052478
    7.01024174   5.96652542]
 [ 28.70170463   6.34442725  20.68834095  22.59107106  26.5337786
   46.434774     2.6596951 ]
 [ -1.4121484   38.74704012  32.08126515   6.11222427  18.2821172
   33.53854054   9.91112895]
 [ -4.39534485   3.51724037  15.53165874  -2.

In [27]:
print(grad_hist[0])
print(grad_hist[4])

[[-4.33680869e-19  1.08420217e-18  6.50521303e-19 -1.95156391e-18
  -3.25260652e-19  8.67361738e-19  0.00000000e+00]
 [ 1.73472348e-18  5.14996032e-19 -2.16840434e-19  8.67361738e-19
   3.25260652e-19  2.71050543e-19  4.33680869e-19]
 [ 9.75781955e-19  1.08420217e-18 -4.33680869e-19  1.51788304e-18
   8.67361738e-19 -2.16840434e-19  6.50521303e-19]
 [-6.50521303e-19  0.00000000e+00 -1.08420217e-18  1.73472348e-18
   0.00000000e+00  1.51788304e-18  8.67361738e-19]
 [-4.33680869e-19 -4.33680869e-19  4.33680869e-19  6.50521303e-19
   2.16840434e-18  6.50521303e-19  4.33680869e-19]
 [ 6.50521303e-19 -2.16840434e-19 -1.95156391e-18  4.33680869e-19
   1.08420217e-18  4.33680869e-19 -2.16840434e-19]
 [ 0.00000000e+00  4.33680869e-19 -2.16840434e-19  1.08420217e-19
   2.16840434e-19 -2.03287907e-19 -2.16840434e-19]]
[[ 2.16840434e-19 -2.38524478e-18 -1.30104261e-18 -4.33680869e-19
  -4.33680869e-19 -6.50521303e-19 -4.33680869e-19]
 [ 2.16840434e-19 -8.94466792e-19  2.16840434e-19 -2.16840434e-