In [1]:
import pennylane as qml
import jax
from jax import numpy as jnp
import optax
import circuit_lib as cl
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import r2_score
print(jax.devices())
jax.config.update('jax_enable_x64', True) 

[CpuDevice(id=0)]


In [2]:
df = pd.read_csv("apple_quality_balanced_dataset.csv", header = 0)

df.loc[df["Quality"] == "good", "Quality"] = 1
df.loc[df["Quality"] == "bad", "Quality"] = 0
df = df.drop(columns=["A_id"])

scaler = MinMaxScaler(feature_range=(0,1))
scaled_x = scaler.fit_transform(df.iloc[:,:7])

train = scaled_x[:850]
test = scaled_x[850:]

x_train = jnp.array(train)
y_train = jnp.array(df.iloc[:850,7].to_numpy(dtype= np.float32))

x_test = jnp.array(test)
y_test = jnp.array(df.iloc[850:,7].to_numpy(dtype= np.float32))

In [3]:
import pandas as pd
import numpy as np
from sklearn.datasets import make_classification

# Generate a synthetic dataset
n_samples = 1000  # Number of samples
n_features = 7    # Number of features
n_classes = 2     # Binary classification (0 and 1)

X, y = make_classification(
    n_samples=n_samples,
    n_features=7,    # Total number of features
    n_informative=7, # All features are informative
    n_redundant=0,   # No redundant features
    n_repeated=0,    # No repeated features
    n_classes=2,     # Binary classification
    random_state=42  # Reproducibility
)

# Convert to a DataFrame for better visualization and manipulation
columns = [f"Feature_{i+1}" for i in range(n_features)]
df = pd.DataFrame(X, columns=columns)
df['Quality'] = y  # Add the target column

# Save the dataset as a CSV file
df.to_csv("binary_classification_dataset.csv", index=False)

# Print first few rows
print(df.head())

from sklearn.preprocessing import MinMaxScaler

df.loc[df["Quality"] == "good", "Quality"] = 1
df.loc[df["Quality"] == "bad", "Quality"] = 0

scaler = MinMaxScaler(feature_range=(0,1))
scaled_x = scaler.fit_transform(df.iloc[:,:7])

x_train = scaled_x[:40]
x_test = scaled_x[850:]

y_train = df.iloc[:40,7].to_numpy(dtype= np.float32)
y_test = df.iloc[850:,7].to_numpy(dtype= np.float32)

   Feature_1  Feature_2  Feature_3  Feature_4  Feature_5  Feature_6  \
0   0.177719  -2.681202  -0.847667  -0.817180   0.335193  -1.735526   
1  -1.001741  -1.565734   1.461369   1.511607   0.897442  -3.641585   
2   1.457602  -0.156201   0.901834   1.021681   1.112249   1.511980   
3   0.156688  -4.783146  -0.484698  -0.754672  -2.574577  -0.356063   
4   0.780096  -0.592755   2.002540  -3.219841  -0.585504   2.134278   

   Feature_7  Quality  
0   2.497458        0  
1   1.453947        0  
2   1.318246        1  
3  -0.682549        0  
4   1.464577        0  


In [4]:
dev = qml.device("default.qubit", wires = 7)

ZZFeatureMap = cl.create_ZZFeatureMap(dev)

@qml.qnode(dev, interface="jax")
def EmbeddingCircuit(x):
    ZZFeatureMap(x)
    return qml.state()


@jax.jit
def embeddor(x):
    return EmbeddingCircuit(x)

embedding_map = jax.vmap(embeddor,[0])

x_train_embedding = embedding_map(x_train)
x_test_embedding = embedding_map(x_test)

In [4]:
num_reps = 12
entanglement = cl.LINEAR 

In [6]:
dev = qml.device("default.qubit", wires = 7)
ZZFeatureMap = cl.create_ZZFeatureMap(dev)
TwoLocal = cl.create_TwoLocal(dev,num_reps,entanglement)
@qml.qnode(dev, interface = "jax")
def QuantumCircuit(params, x):
    ZZFeatureMap(x)
    TwoLocal(params)
    return qml.probs(wires=[0])

In [7]:
@jax.jit
def BCE(params, data, target):
    sigmoid = 1/(1+ jnp.exp(-1*QuantumCircuit(params, data)[0]))
    sigmoid = jnp.clip(sigmoid, 1e-7, 1 - 1e-7)
    return  (target * jnp.log(sigmoid) + (1 - target) * jnp.log(1 - sigmoid))

BCE_map = jax.vmap(BCE, (None, 0, 0))

@jax.jit
def loss_fn(params, data, target):
    return jnp.mean(BCE_map(params, data, target))

In [8]:
learning_rate_schedule = optax.schedules.join_schedules(schedules = [optax.constant_schedule((0.5)**i) for i in range(6)],
                                                                     boundaries = [100, 250, 500, 700, 900])
opt = optax.lbfgs(1e-8)
max_steps = 200
@jax.jit
def f(params):
    return loss_fn(params, x_train, y_train)
value_and_grad = optax.value_and_grad_from_state(f)
@jax.jit
def optimiser(params, data, targets, print_training):
    param_history = jnp.zeros(((max_steps + 1),) + params.shape)  
    grad_history = jnp.zeros(((max_steps + 1),) + params.shape)  
    opt_state = opt.init(params)
    #Packages the arguments to be sent to the function update_step_jit 
    args = (params, opt_state, jnp.asarray(data), targets, print_training, param_history, grad_history)
    #Loops max_steps number of times
    (params, opt_state, _, _, _, param_history, grad_history) = jax.lax.fori_loop(0, max_steps+1, update_step_jit, args) 
    return params, param_history, grad_history

@jax.jit
def update_step_jit(i,args):
    # Unpacks the arguments
    params, opt_state, data, targets, print_training, param_history, grad_history = args
    param_history = param_history.at[i].set(params)
    # Gets the loss and the gradients to be applied to the parameters, by passing in the loss function and the parameters, to see how the parameters perform 
    loss_val, grads = value_and_grad(params, state = opt_state)
    grad_history = grad_history.at[i].set(grads)
    #Prints the loss every 25 steps if print_training is enable
    def print_fn():
        jax.debug.print("Step: {i}  Loss: {loss_val}", i=i, loss_val= loss_val)
    jax.lax.cond((jnp.mod(i, 50) == 0 ) & print_training, print_fn, lambda: None)
    #Applies the param updates and updates the optimiser states
    updates, opt_state = opt.update(
            grads, opt_state, params, value=loss_val, grad=grads, value_fn=f
        )
    params =  optax.apply_updates(params, updates)
    #Returns the arguments to be resupplied in the next iteration
    return (params, opt_state, data, targets, print_training, param_history, grad_history)   


In [9]:
from math import pi
params = jnp.array(np.random.default_rng().random(size = (num_reps + 1, 7)))*2*pi - pi
params, params_hist, grad_hist = optimiser(params, x_train, y_train, True)
print(loss_fn(params, x_test, y_test))

Step: 0  Loss: -0.7990769841801065


NameError: name 'x_test_embedding' is not defined

Step: 50  Loss: -0.7990769841801065
Step: 100  Loss: -0.7990769841801065
Step: 150  Loss: -0.7990769841801065
Step: 200  Loss: -0.7990769841801065


In [19]:
print(params_hist[0])
print(params_hist[500])
print(params_hist[0].shape)

[[ 1.16809215 -0.46130867  2.24321916  1.8193395  -0.41483542 -1.4582453
  -0.84811708]
 [ 0.82518368 -1.00287381 -1.61947866 -0.71348055  2.20479855 -1.64447686
   1.96090013]
 [-1.28673911  1.49313392 -0.0878167  -0.02091742  1.92579819 -0.87373926
  -2.97448706]
 [-2.88824842  0.82052263  1.35324466  2.05261903  0.84653078 -1.12220721
  -0.3323143 ]]
[[ 1.31798885 -0.46130867  2.24321916  1.8193395  -0.41483542 -1.4582453
  -0.84811708]
 [ 1.31980402 -1.00287381 -1.61947866 -0.71348055  2.20479855 -1.64447686
   1.96090013]
 [-1.04911699  1.49313392 -0.0878167  -0.02091742  1.92579819 -0.87373926
  -2.97448706]
 [-2.43080033  0.82052263  1.35324466  2.05261903  0.84653078 -1.12220721
  -0.3323143 ]]
(4, 7)


In [20]:
print(grad_hist[0])
print(grad_hist[400])

[[-8.00162597e-04  1.73472348e-18  3.46944695e-18  0.00000000e+00
   0.00000000e+00 -6.93889390e-18 -3.46944695e-18]
 [-1.34199328e-03  3.46944695e-18  0.00000000e+00  0.00000000e+00
  -3.46944695e-18  0.00000000e+00  3.46944695e-18]
 [-1.21186954e-03  0.00000000e+00  4.33680869e-19 -2.16840434e-19
  -3.46944695e-18  3.46944695e-18 -8.67361738e-19]
 [-3.98594499e-03 -6.93889390e-18 -3.46944695e-18  0.00000000e+00
   0.00000000e+00  0.00000000e+00  0.00000000e+00]]
[[ 1.55322990e-04  1.73472348e-18  0.00000000e+00  0.00000000e+00
   0.00000000e+00 -3.46944695e-18  0.00000000e+00]
 [-1.49050115e-03 -3.46944695e-18  3.46944695e-18  3.46944695e-18
   0.00000000e+00  0.00000000e+00  0.00000000e+00]
 [ 1.88159813e-05 -6.93889390e-18  0.00000000e+00  1.08420217e-18
  -3.46944695e-18  0.00000000e+00  0.00000000e+00]
 [-2.97333308e-03  0.00000000e+00  0.00000000e+00  6.93889390e-18
   3.46944695e-18  0.00000000e+00  0.00000000e+00]]


In [None]:
y_pred = np.array([QuantumCircuit(params,x_test[i]) for i in range(40)])
y_pred = y_pred[:,1]

In [74]:
print(r2_score(y_test,y_pred))

-0.054754496248425655
