In [1]:
import pennylane as qml
from pennylane import numpy as np
from sklearn.datasets import load_iris
from pennylane.optimize import AdamOptimizer, GradientDescentOptimizer

In [2]:
data = load_iris()

In [3]:
x = data['data'][:100]
y = data['target'][:100]

In [4]:
dev1 = qml.device('lightning.qubit', wires=4)

In [5]:
@qml.qnode(dev1, diff_method="parameter-shift")
def circuit(enc_params, data):
    qml.AngleEmbedding(data, wires=range(4))
    qml.BasicEntanglerLayers(enc_params, wires=range(4))
    return qml.expval(qml.PauliZ(0))

In [6]:
weight_shapes = {"weights": (4, 4)}
params = np.random.uniform(size=weight_shapes["weights"], requires_grad=True)

In [7]:
learning_rate = 0.1
epochs = 10
batch_size = 10 

In [8]:
#opt = AdamOptimizer(learning_rate, beta1=0.9, beta2=0.999)
opt = qml.QNGOptimizer(learning_rate)
# opt = GradientDescentOptimizer(learning_rate)

In [9]:
def cost_batch(params, batch, y):
    loss = 0.0
    for i, data in enumerate(batch):
        f = circuit(params, data)
        loss = loss + (y[i] - f) ** 2
    return loss / len(batch)

def cost_sample(params, x, y):
    return (y - circuit(params, x)) ** 2

In [10]:
def iterate_minibatches(data, y, batch_size):
    for start_idx in range(0, data.shape[0] - batch_size + 1, batch_size):
        idxs = slice(start_idx, start_idx + batch_size)
        yield data[idxs], y[idxs]

In [11]:


for it in range(epochs):
    #for j,batch in enumerate(iterate_minibatches(x,y, batch_size=batch_size)):
        #Xbatch, ybatch = batch
    for j, sample in enumerate(x):        
        sample = np.array(sample, requires_grad=False)
        #cost_fn = lambda p: cost_batch(p, Xbatch, ybatch)
        cost_fn = lambda p: cost_sample(p, sample, y[j])
        metric_fn = lambda p: qml.metric_tensor(circuit, approx="block-diag")(p, sample)
        params = opt.step(cost_fn, params, metric_tensor_fn=metric_fn)
        print(j, end="\r")
    
    loss = cost_batch(params,x,y)
    
    print(f"Epoch: {it} | Loss: {loss} |")

Epoch: 0 | Loss: 0.21751668366493704 |
Epoch: 1 | Loss: 0.21750394237935705 |
27


KeyboardInterrupt

