In [None]:
import pennylane as qml
import jax
from jax import numpy as jnp
import optax
import circuit_lib as cl
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import r2_score
# JAX documents jax_enable_x64 as a startup setting, so enable 64-bit floats
# before any other JAX call (including jax.devices()).
jax.config.update('jax_enable_x64', True)
print(jax.devices())

In [5]:
# Load the apple-quality data and encode the label as 1 ("good") / 0 ("bad").
df = pd.read_csv("apple_quality_balanced_dataset.csv", header = 0)

df.loc[df["Quality"] == "good", "Quality"] = 1
df.loc[df["Quality"] == "bad", "Quality"] = 0
df = df.drop(columns=["A_id"])

N_FEATURES = 7   # first 7 columns (after dropping A_id) are features; column 7 is Quality
N_TRAIN = 850    # first N_TRAIN rows form the training set, the rest the test set

features = df.iloc[:, :N_FEATURES]
labels = df.iloc[:, N_FEATURES].to_numpy(dtype=np.float32)

# Fit the scaler on the training rows only, so test-set min/max do not leak
# into preprocessing. Test values may therefore fall slightly outside [0, 1].
scaler = MinMaxScaler(feature_range=(0,1))
train = scaler.fit_transform(features.iloc[:N_TRAIN])
test = scaler.transform(features.iloc[N_TRAIN:])
scaled_x = np.concatenate([train, test])  # full scaled matrix, kept for parity with earlier versions

x_train = jnp.array(train)
y_train = jnp.array(labels[:N_TRAIN])

x_test = jnp.array(test)
y_test = jnp.array(labels[N_TRAIN:])

In [None]:
import pandas as pd
import numpy as np
from sklearn.datasets import make_classification

# NOTE(review): this cell rebinds df, scaler, scaled_x, x_train, y_train,
# x_test and y_test, replacing the apple-quality data loaded earlier. On
# Restart & Run All, every later cell trains and evaluates on this synthetic
# data instead — confirm this is intended.

# Generate a synthetic dataset
n_samples = 1000  # Number of samples
n_features = 7    # Number of features
n_classes = 2     # Binary classification (0 and 1)

n_train = 40      # rows [0, n_train) are the training set
test_start = 850  # rows [test_start, n_samples) are the test set; rows in between are unused

X, y = make_classification(
    n_samples=n_samples,
    n_features=n_features,
    n_informative=n_features,  # All features are informative
    n_redundant=0,             # No redundant features
    n_repeated=0,              # No repeated features
    n_classes=n_classes,
    random_state=42            # Reproducibility
)

# Convert to a DataFrame for better visualization and manipulation
columns = [f"Feature_{i+1}" for i in range(n_features)]
df = pd.DataFrame(X, columns=columns)
df['Quality'] = y  # Add the target column (already 0/1 ints, so no "good"/"bad" mapping is needed)

# Save the dataset as a CSV file
df.to_csv("binary_classification_dataset.csv", index=False)

# Print first few rows
print(df.head())

from sklearn.preprocessing import MinMaxScaler

features = df.iloc[:, :n_features]
labels = df.iloc[:, n_features].to_numpy(dtype=np.float32)

# Fit the scaler on training rows only so test statistics do not leak in.
scaler = MinMaxScaler(feature_range=(0,1))
x_train = scaler.fit_transform(features.iloc[:n_train])
x_test = scaler.transform(features.iloc[test_start:])
scaled_x = scaler.transform(features)  # whole dataset, scaled with the train-fitted scaler

y_train = labels[:n_train]
y_test = labels[test_start:]

In [None]:
dev = qml.device("default.qubit", wires=7)
ZZFeatureMap = cl.create_ZZFeatureMap(dev)


@qml.qnode(dev, interface="jax")
def EmbeddingCircuit(x):
    """Encode one feature vector ``x`` with the ZZ feature map and return the statevector."""
    ZZFeatureMap(x)
    return qml.state()


@jax.jit
def embeddor(x):
    """JIT-compiled single-sample wrapper around ``EmbeddingCircuit``."""
    statevector = EmbeddingCircuit(x)
    return statevector


# Batch the embedding over axis 0 (one sample per row).
embedding_map = jax.vmap(embeddor, in_axes=0)

x_train_embedding, x_test_embedding = map(embedding_map, (x_train, x_test))

In [7]:
# Ansatz hyperparameters: number of (CNOT chain + rotation layer) repetitions,
# and the entanglement pattern handed to cl.create_TwoLocal.
num_reps, entanglement = 6, cl.LINEAR

In [None]:
N_QUBITS = 7  # one qubit per input feature (the data cells produce 7 feature columns)

dev = qml.device("default.qubit", wires = N_QUBITS)
ZZFeatureMap = cl.create_ZZFeatureMap(dev)


def phase_rotation(params):
    """Apply RZ(params[i]) on wire i for i = 0 .. N_QUBITS - 1.

    Not used by QuantumCircuit any more; see QuantumCircuit's docstring for
    why an RZ-only ansatz cannot be trained there.
    """
    for wire in range(N_QUBITS):
        qml.RZ(params[wire], wire)


def rotation_layer(params):
    """Apply RY(params[i]) on wire i for i = 0 .. N_QUBITS - 1."""
    for wire in range(N_QUBITS):
        qml.RY(params[wire], wires=wire)


# NOTE(review): TwoLocal is built here but QuantumCircuit does not call it; the
# ansatz below hand-codes a linear CNOT chain, so `entanglement` only affects TwoLocal.
TwoLocal = cl.create_TwoLocal(dev,num_reps,entanglement)


@qml.qnode(dev, interface = "jax")
def QuantumCircuit(params, x):
    """Variational classifier: ZZ feature map followed by a trainable RY/CNOT ansatz.

    The ansatz is one RY layer, then ``num_reps`` repetitions of
    (CNOT(i, i+1) chain, RY layer). ``params`` must therefore hold
    ``(num_reps + 1) * N_QUBITS`` angles, the same count as before.

    Why RY rather than RZ: RZ is diagonal, and CNOT(i, i+1) never flips
    wire 0. With only RZ and this CNOT chain, the wire-0 probabilities
    returned below do not depend on ``params`` at all. Every gradient is
    then zero and training cannot make progress.

    Args:
        params: flat vector of rotation angles.
        x: one feature vector for the ZZ feature map.

    Returns:
        ``qml.probs`` on wire 0, i.e. ``[P(0), P(1)]``.
    """
    ZZFeatureMap(x)
    rotation_layer(params[0:N_QUBITS])
    for layer in range(1, num_reps + 1):
        for wire in range(N_QUBITS - 1):
            qml.CNOT(wires = [wire, wire + 1])
        rotation_layer(params[N_QUBITS * layer : N_QUBITS * (layer + 1)])
    return qml.probs(wires=[0])

In [9]:
@jax.jit
def BCE(params, data, target):
    sigmoid = 1/(1+ jnp.exp(-1*QuantumCircuit(params, data)[0]))
    sigmoid = jnp.clip(sigmoid, 1e-7, 1 - 1e-7)
    return  -1*(target * jnp.log(sigmoid) + (1 - target) * jnp.log(1 - sigmoid))

BCE_map = jax.vmap(BCE, (None, 0, 0))

@jax.jit
def loss_fn(params, data, target):
    """Mean per-sample binary cross-entropy over the batch ``data`` / ``target``."""
    per_sample = BCE_map(params, data, target)
    return per_sample.mean()

In [14]:
# Piecewise-constant schedule: 1.0, 0.5, 0.25, ... halving at each boundary.
# NOTE(review): this schedule is not passed to the optimiser (adam below uses a
# fixed 1e-4), and most boundaries lie beyond max_steps — confirm whether it is needed.
_lr_boundaries = [100, 250, 500, 700, 900]
_lr_pieces = [optax.constant_schedule(0.5 ** k) for k in range(len(_lr_boundaries) + 1)]
learning_rate_schedule = optax.schedules.join_schedules(schedules=_lr_pieces, boundaries=_lr_boundaries)

opt = optax.adam(1e-4)
max_steps = 200  # the training fori_loop runs max_steps + 1 iterations
@jax.jit
def optimiser(params, data, targets, print_training):
    """Run ``max_steps + 1`` optimiser updates of ``loss_fn`` inside one ``fori_loop``.

    Args:
        params: initial parameter vector.
        data: training features, batched along axis 0.
        targets: training labels matching ``data``.
        print_training: boolean flag forwarded to ``update_step_jit`` to enable
            periodic loss printing.

    Returns:
        ``(final_params, param_history, grad_history)``. Row ``i`` of each history
        holds the parameters / gradients recorded at loop step ``i``.
    """
    n_steps = max_steps + 1
    history_shape = (n_steps, *params.shape)
    carry = (
        params,
        opt.init(params),
        jnp.asarray(data),
        targets,
        print_training,
        jnp.zeros(history_shape),
        jnp.zeros(history_shape),
    )
    final_params, _opt_state, *_passthrough, param_history, grad_history = jax.lax.fori_loop(
        0, n_steps, update_step_jit, carry
    )
    return final_params, param_history, grad_history

@jax.jit
def update_step_jit(i,args):
    """One ``fori_loop`` body step: record state, differentiate ``loss_fn``, apply an update.

    ``args`` is the loop carry
    ``(params, opt_state, data, targets, print_training, param_history, grad_history)``.
    A tuple with the same structure is returned for the next iteration.
    """
    (params, opt_state, data, targets, print_training,
     param_history, grad_history) = args

    # Loss and gradient at the current (pre-update) parameters.
    loss_val, grads = jax.value_and_grad(loss_fn)(params, data, targets)
    param_history = param_history.at[i].set(params)
    grad_history = grad_history.at[i].set(grads)

    # Print the loss every 50 steps when print_training is enabled.
    should_print = (jnp.mod(i, 50) == 0) & print_training
    jax.lax.cond(
        should_print,
        lambda: jax.debug.print("Step: {i}  Loss: {loss_val}", i=i, loss_val=loss_val),
        lambda: None,
    )

    updates, next_opt_state = opt.update(grads, opt_state, params)
    next_params = optax.apply_updates(params, updates)
    return (next_params, next_opt_state, data, targets, print_training, param_history, grad_history)


In [None]:
from math import pi

SEED = 0  # fix the initial-parameter draw so training runs are reproducible
rng = np.random.default_rng(SEED)

# (num_reps + 1) layers x 7 qubits of angles, each uniform in [-pi, pi).
params = jnp.array(rng.random(size = ((num_reps + 1 )* 7,)))*2*pi - pi
params, params_hist, grad_hist = optimiser(params, x_train, y_train, True)
print(loss_fn(params, x_test, y_test))