In [66]:
import pennylane as qml
import jax
from jax import numpy as jnp
import optax
import circuit_lib as cl
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
print(jax.devices())
jax.config.update('jax_enable_x64', True) 

[CudaDevice(id=0)]


In [67]:
df = pd.read_csv("apple_quality_balanced_dataset.csv", header = 0)

df.loc[df["Quality"] == "good", "Quality"] = 1
df.loc[df["Quality"] == "bad", "Quality"] = 0
df = df.drop(columns=["A_id"])

scaler = MinMaxScaler(feature_range=(-1,1))
scaled_x = scaler.fit_transform(df.iloc[:,:7])

train = scaled_x[:850]
test = scaled_x[850:]

x_train = jnp.array(train)
y_train = jnp.array(df.iloc[:850,7].to_numpy(dtype= np.float32))

x_test = jnp.array(test)
y_test = jnp.array(df.iloc[850:,7].to_numpy(dtype= np.float32))

In [68]:
dev = qml.device("default.qubit", wires = 7)

ZZFeatureMap = cl.create_ZZFeatureMap(dev)

@qml.qnode(dev, interface="jax")
def EmbeddingCircuit(x):
    ZZFeatureMap(x)
    return qml.state()


@jax.jit
def embeddor(x):
    return EmbeddingCircuit(x)

embedding_map = jax.vmap(embeddor,[0])

x_train_embedding = embedding_map(x_train)
x_test_embedding = embedding_map(x_test)

In [78]:
num_reps = 6
entanglement = cl.CIRCULAR   

wires = dev.wires
num_wires = len(wires)
def entanglement_layer():
    qml.CNOT(wires = [wires[num_wires - 1], wires[0]])
    for i in range(num_wires - 1):
        qml.CNOT(wires = [wires[i], wires[i+1]])

def phase_rotation(params):
    for i in range(num_wires):
        qml.PhaseShift(params[i], wires[i])
def TwoLocal(params):
    phase_rotation(params[0])
    for layer in range(1, num_reps + 1):
        entanglement_layer()
        phase_rotation(params[layer])


In [79]:
observable = qml.numpy.asarray([[0,0],
                        [0,1]])

observable = qml.numpy.array([
    [1, 0, 0, 0, 0, 0, 0, 0],
    [0, 1, 0, 0, 0, 0, 0, 0],
    [0, 0, 1, 0, 0, 0, 0, 0],
    [0, 0, 0, 1, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0]
])

@qml.qnode(dev, interface = "jax")
def QuantumCircuit(params, embedding):
    qml.StatePrep(embedding, wires = [0,1,2,3,4,5,6], normalize= True)
    TwoLocal(params)
    return qml.expval(qml.Hermitian(observable, wires=[0,1,2]))


In [80]:
@jax.jit
def mse(params, data, target):
    return (target - QuantumCircuit(params, data))**2

mse_map = jax.vmap(mse, (None, 0, 0))

@jax.jit
def loss_fn(params, data, target):
    return jnp.mean(mse_map(params, data, target))

In [98]:
opt = optax.adam(0.1)
max_steps = 500
@jax.jit
def optimiser(params, data, targets, print_training):
    param_history = jnp.zeros(((max_steps + 1),) + params.shape)  
    opt_state = opt.init(params)
    #Packages the arguments to be sent to the function update_step_jit 
    args = (params, opt_state, jnp.asarray(data), targets, print_training, param_history)
    #Loops max_steps number of times
    (params, opt_state, _, _, _, param_history) = jax.lax.fori_loop(0, max_steps+1, update_step_jit, args) 
    return params, param_history

@jax.jit
def update_step_jit(i,args):
    # Unpacks the arguments
    params, opt_state, data, targets, print_training, param_history = args
    param_history = param_history.at[i].set(params)
    # Gets the loss and the gradients to be applied to the parameters, by passing in the loss function and the parameters, to see how the parameters perform 
    loss_val, grads = jax.value_and_grad(loss_fn)(params, data, targets)
    #Prints the loss every 25 steps if print_training is enable
    def print_fn():
        jax.debug.print("Step: {i}  Loss: {loss_val} Gradient norm: {g}", g= jnp.linalg.norm(grads), i=i, loss_val=loss_val)
    jax.lax.cond((jnp.mod(i, 50) == 0 ) & print_training, print_fn, lambda: None)
    #Applies the param updates and updates the optimiser states
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    #Returns the arguments to be resupplied in the next iteration
    return (params, opt_state, data, targets, print_training, param_history)   


In [99]:
init_params = jnp.ones(shape=(num_reps + 1, 7)
)
opt_params, params_hist = optimiser(init_params, x_train_embedding, y_train, True)
print(loss_fn(opt_params, x_test_embedding, y_test))

Step: 0  Loss: 0.24999999999999997 Gradient norm: 8.245640863605515e-18
Step: 50  Loss: 0.24999999999999997 Gradient norm: 7.21137088772849e-18
Step: 100  Loss: 0.24999999999999997 Gradient norm: 7.224399572041235e-18
Step: 150  Loss: 0.24999999999999997 Gradient norm: 6.86052430991425e-18
Step: 200  Loss: 0.24999999999999997 Gradient norm: 5.642019764257083e-18
Step: 250  Loss: 0.24999999999999997 Gradient norm: 7.126103841197012e-18
Step: 300  Loss: 0.24999999999999997 Gradient norm: 6.573520543531243e-18
Step: 350  Loss: 0.24999999999999997 Gradient norm: 7.61724367186636e-18
Step: 400  Loss: 0.24999999999999997 Gradient norm: 6.1101309791293226e-18
Step: 450  Loss: 0.24999999999999997 Gradient norm: 6.952433221959954e-18
Step: 500  Loss: 0.24999999999999997 Gradient norm: 6.3033093734794044e-18
0.25


In [None]:
print(params_hist[0])
print(params_hist[100])
print(params_hist[0].shape)

[[-0.12520351 -0.33444239  0.62504313  0.26925998  0.19504671 -0.23168405
  -0.56353509]
 [ 0.20597923 -0.017325   -0.31597895  0.4934467  -0.65271617 -0.3715017
  -0.22335656]
 [ 0.1572723  -0.17718732 -0.2554929   0.17977453  0.01274453  0.80102742
  -0.45218052]
 [ 0.08611685  0.06476738  0.08938308 -0.77294991 -0.41919139 -0.04938587
  -0.45253485]
 [-0.34219927  0.23687799 -0.57380386 -0.05112175  0.451848   -0.30871563
  -0.44209978]
 [-0.85646676 -0.31312522 -0.06690812 -0.00941428 -0.35953292  0.14855119
   0.11189541]
 [ 0.24312609 -0.83552342 -0.32013258 -0.22701744  0.16310733 -0.21645392
   0.12374261]]
[[-0.12520351 -0.33444239  0.62504313  0.26925998  0.19504671 -0.23168405
  -0.56353509]
 [ 0.20597923 -0.017325   -0.31597895  0.4934467  -0.65271617 -0.3715017
  -0.22335656]
 [ 0.1572723  -0.17718732 -0.2554929   0.17977453  0.01274453  0.80102742
  -0.45218052]
 [ 0.08611685  0.06476738  0.08938308 -0.77294991 -0.41919139 -0.04938587
  -0.45253485]
 [-0.34219927  0.23687