In [1]:
import numpy as np
import qiskit as qk
import matplotlib.pyplot as plt

from qiskit import Aer
from tqdm.notebook import tqdm
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from skimage.measure import block_reduce

import sys
sys.path.insert(0, '../../src/')
from neuralnetwork import *
from analysis import *

%matplotlib notebook
#%matplotlib inline
%load_ext autoreload
%autoreload 2

### Digits data

In [2]:
digits = load_digits()
three_idx = (digits.target == 3)
six_idx = (digits.target == 6)

In [3]:
threes = digits.data[three_idx]
sixes =  digits.data[six_idx]

x = np.concatenate((threes, sixes))
x = [block_reduce(image.reshape(8, 8), (2, 2), func=np.mean).reshape(-1) for image in x]
x = np.array(x)

y = np.concatenate((np.zeros(len(threes)), np.ones(len(sixes)))).reshape(-1,1)

In [4]:
np.random.seed(42)
x_train, x_test, y_train, y_test = train_test_split(x, y, train_size=0.2)
print(x_train.shape)

(72, 16)


### Network

In [5]:
np.random.seed(42)
backend = Aer.get_backend('qasm_simulator')

network1 = sequential_qnn(q_bits=[16, 4],
                         dim=[16, 4, 1],
                         reps = 1,
                         backend=backend,
                         shots=10000,
                         lr=0.1
                         )

### Training

In [6]:
network1.train(x_train, y_train, epochs=100, verbose=True)

  0%|          | 0/100 [00:00<?, ?it/s]

epoch: 0, loss: 0.2582664602777778
epoch: 1, loss: 0.24961145861111111
epoch: 2, loss: 0.2438326081944444
epoch: 3, loss: 0.2432853488888889
epoch: 4, loss: 0.24333207722222225
epoch: 5, loss: 0.24065142152777774
epoch: 6, loss: 0.23672300111111116
epoch: 7, loss: 0.23485696888888885
epoch: 8, loss: 0.23332320305555557
epoch: 9, loss: 0.22634395444444444
epoch: 10, loss: 0.22324502680555558
epoch: 11, loss: 0.22077186611111113
epoch: 12, loss: 0.21846342458333334
epoch: 13, loss: 0.2162096177777778
epoch: 14, loss: 0.21250792805555557
epoch: 15, loss: 0.2095244423611111
epoch: 16, loss: 0.2068099634722222
epoch: 17, loss: 0.20985439972222222
epoch: 18, loss: 0.20775253777777777
epoch: 19, loss: 0.20386455569444445
epoch: 20, loss: 0.20379286
epoch: 21, loss: 0.20385990305555557
epoch: 22, loss: 0.20048300694444443
epoch: 23, loss: 0.20181725416666665
epoch: 24, loss: 0.19925808430555556
epoch: 25, loss: 0.19683705444444444
epoch: 26, loss: 0.1983087063888889
epoch: 27, loss: 0.19581106

In [7]:
saver(network1, data_path("QNN_digits_16qb"))

In [8]:
np.random.seed(42)
network2 = sequential_dnn(dim=[16, 4, 1], lr=0.1)

In [9]:
network2.train(x_train, y_train, epochs=100, verbose=True)

  0%|          | 0/100 [00:00<?, ?it/s]

epoch: 0, loss: 0.21717602763437535
epoch: 1, loss: 0.16388157680374751
epoch: 2, loss: 0.12676802310603488
epoch: 3, loss: 0.11055128901867321
epoch: 4, loss: 0.08722395795675966
epoch: 5, loss: 0.06505103617194709
epoch: 6, loss: 0.05474521315133621
epoch: 7, loss: 0.04831287441906112
epoch: 8, loss: 0.043696138752481115
epoch: 9, loss: 0.039763530194833335
epoch: 10, loss: 0.03597268839270532
epoch: 11, loss: 0.032183066809473804
epoch: 12, loss: 0.028425222629953173
epoch: 13, loss: 0.024740829791258166
epoch: 14, loss: 0.021221140539797657
epoch: 15, loss: 0.018045989692695306
epoch: 16, loss: 0.015337715725036295
epoch: 17, loss: 0.013099251909122483
epoch: 18, loss: 0.011277673916117288
epoch: 19, loss: 0.009808049157418378
epoch: 20, loss: 0.008627184716671673
epoch: 21, loss: 0.007678212534981987
epoch: 22, loss: 0.00691242465441699
epoch: 23, loss: 0.006289623189577449
epoch: 24, loss: 0.005777554529188249
epoch: 25, loss: 0.005350882306110686
epoch: 26, loss: 0.0049900202315

In [10]:
y_pred = network2.predict(x_train)

In [12]:
print(np.mean(np.round(y_pred) != y_train))

[[1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [0.]
 [0.]
 [1.]
 [1.]
 [1.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [1.]
 [0.]
 [1.]
 [0.]
 [0.]
 [1.]
 [1.]
 [0.]
 [1.]
 [0.]
 [1.]
 [0.]
 [1.]
 [0.]
 [1.]
 [1.]
 [1.]
 [1.]
 [0.]
 [0.]
 [0.]
 [1.]
 [1.]
 [1.]
 [0.]
 [0.]
 [0.]
 [1.]
 [1.]
 [0.]
 [1.]
 [0.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [1.]
 [1.]
 [0.]
 [1.]
 [0.]
 [1.]
 [0.]
 [0.]
 [1.]
 [1.]
 [0.]]
