In [2]:
import pennylane as qml
import jax
from jax import numpy as jnp
import optax
import circuit_lib as cl
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

print(jax.devices())

[CudaDevice(id=0)]


In [3]:
df = pd.read_csv("apple_quality_balanced_dataset.csv", header = 0)

df.loc[df["Quality"] == "good", "Quality"] = 1
df.loc[df["Quality"] == "bad", "Quality"] = 0
df = df.drop(columns=["A_id"])
train = df.iloc[:850]
test = df.iloc[850:]

x_train = jnp.array(train.iloc[:,:7].to_numpy())
y_train = jnp.array(train.iloc[:,7].to_numpy(dtype= np.float32))

x_test = jnp.array(test.iloc[:,:7].to_numpy())
y_test = jnp.array(test.iloc[:,7].to_numpy(dtype= np.float32))

In [4]:
num_reps = 3
entanglement = cl.LINEAR

dev = qml.device("lightning.gpu", wires = 7)

ZZFeatureMap = cl.create_ZZFeatureMap(dev)
TwoLocal = cl.create_TwoLocal(dev, num_reps, entanglement)

observable = qml.numpy.asarray([[0,0],
                        [0,1]])

@qml.qnode(dev, interface = "jax")
def QuantumCircuit(params, x):
    ZZFeatureMap(x)
    TwoLocal(params)
    return qml.expval(qml.Hermitian(observable, wires=[0]))

In [5]:
@jax.jit
def abse(params, data, target):
    return jnp.abs(target - QuantumCircuit(params, data))

abse_map = jax.vmap(abse, (None, 0, 0))

@jax.jit
def loss_fn(params, data, target):
    return jnp.mean(abse_map(params, data, target))

In [6]:
opt = optax.adam(1)
max_steps = 100
init_params = jnp.array(np.random.default_rng().random(size = (num_reps + 1, 7))*3)

In [7]:
@jax.jit
def optimiser(params, data, training , print_training):
    opt_state = opt.init(params)
    args = (params, opt_state, jnp.asarray(data), jnp.asarray(training),print_training)
    (params, opt_state, _, _, _) = jax.lax.fori_loop(0, max_steps+1, update_step_jit, args) 
    return params

@jax.jit
def update_step_jit(i,args):
    # Unpacks the arguments
    params, opt_state, data, targets, print_training = args
    # Gets the loss and the gradients to be applied to the parameters, by passing in the loss function and the parameters, to see how the parameters perform 
    loss_val, grads = jax.value_and_grad(loss_fn)(params, data, targets)
    #Prints the loss every 25 steps if print_training is enable
    def print_fn():
        jax.debug.print("Step: {i}  Loss: {loss_val}", i=i, loss_val=loss_val)
    jax.lax.cond((jnp.mod(i, 25) == 0 ) & print_training, print_fn, lambda: None)
    #Applies the param updates and updates the optimiser states
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    #Returns the arguments to be resupplied in the next iteration
    return (params, opt_state, data, targets, print_training)   

In [8]:
opt_params = optimiser(init_params, x_test, y_test, True)
print(loss_fn(opt_params, x_test, y_test))




Step: 0  Loss: 0.5
Step: 25  Loss: 0.5
Step: 50  Loss: 0.5
Step: 75  Loss: 0.5
Step: 100  Loss: 0.5
0.5
