In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

from importlib.util import find_spec
if find_spec("qml_hep_lhc") is None:
    import sys
    sys.path.append('..')

In [24]:
from qml_hep_lhc.data import ElectronPhoton, MNIST
from qml_hep_lhc.data.utils import tf_ds_to_numpy
from qml_hep_lhc.data.preprocessor import DataPreprocessor
from sklearn.utils import shuffle
import argparse
import wandb

import pennylane as qml
from pennylane import numpy as np
from pennylane.optimize import AdamOptimizer, GradientDescentOptimizer
import jax.numpy as jnp
import jax
from flax.training import train_state
from flax import linen as nn
import optax
from absl import logging
from jax import grad, jit, vmap
from jax import random
import tensorflow_datasets as tfds
from tqdm import tqdm
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt
import time

In [34]:
args = argparse.Namespace()

# Data
args.center_crop = 0.7
args.resize = [8,8]
args.standardize = 1
# args.power_transform = 1
args.binary_data = [0,1]
# args.percent_samples = 0.01
# args.processed = 1
args.dataset_type = 'med'
args.labels_to_categorical = 1
args.batch_size = 128

# Base Model
args.wandb = True
# args.epochs = 30
# args.learning_rate = 0.2

# Quantum CNN Parameters
# args.n_layers = 2
# args.n_qubits = 2
# args.template = 'NQubitPQC'

In [35]:
data = MNIST(args)
data.prepare_data()
data.setup()
print(data)

Binarizing data...
Binarizing data...
Center cropping...
Center cropping...
Resizing data...
Resizing data...
Standardizing data...
Converting labels to categorical...
Converting labels to categorical...

Dataset :MNIST
╒════════╤══════════════════╤═════════════════╤═════════════════╤═══════════╕
│ Data   │ Train size       │ Val size        │ Test size       │ Dims      │
╞════════╪══════════════════╪═════════════════╪═════════════════╪═══════════╡
│ X      │ (10132, 8, 8, 1) │ (2533, 8, 8, 1) │ (2115, 8, 8, 1) │ (8, 8, 1) │
├────────┼──────────────────┼─────────────────┼─────────────────┼───────────┤
│ y      │ (10132, 2)       │ (2533, 2)       │ (2115, 2)       │ (2,)      │
╘════════╧══════════════════╧═════════════════╧═════════════════╧═══════════╛

╒══════════════╤═══════╤═══════╤════════╤═══════╤══════════════════════════╕
│ Type         │   Min │   Max │   Mean │   Std │ Samples for each class   │
╞══════════════╪═══════╪═══════╪════════╪═══════╪══════════════════════════╡
│ 

In [36]:
x_train, y_train = tf_ds_to_numpy(data.train_ds)
x_val, y_val = tf_ds_to_numpy(data.val_ds)
x_test, y_test = tf_ds_to_numpy(data.test_ds)

In [38]:
if args.wandb:
     wandb.init(project='qml-hep-lhc')

[34m[1mwandb[0m: Currently logged in as: [33mgopald[0m. Use [1m`wandb login --relogin`[0m to force relogin
2022-08-15 19:05:44.673957: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory


## Hyperparameters

In [39]:
input_dims = data.config()['input_dims']

In [40]:
def get_out_shape(in_shape, k, s, padding):
    in_shape = (1,) + in_shape
    a = np.random.uniform(size = (in_shape))
    dn = jax.lax.conv_dimension_numbers(a.shape, (1,1,k[0],k[1]), ('NHWC', 'IOHW', 'NHWC'))
    out = jax.lax.conv_general_dilated_patches(lhs = a,
                                           filter_shape= k,
                                           window_strides=s,
                                           padding=padding,
                                           dimension_numbers=dn 
                                    )
    return out.shape

In [41]:
# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-1):
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]
 

kernel_size = (3,3)
strides = (1,1)
padding = "VALID"
clayer_sizes = [8,2]

step_size = 0.2 
num_epochs = 30

n_layers = 2
n_qubits = 4


conv_out_shape = get_out_shape(input_dims, kernel_size, strides, padding)
num_pixels = np.prod(conv_out_shape[:-1])*n_qubits
qlayer_sizes = [np.prod(kernel_size), n_layers*n_qubits*3]
clayer_sizes = [num_pixels] + clayer_sizes

params = init_network_params(qlayer_sizes, random.PRNGKey(0))
params += init_network_params(clayer_sizes, random.PRNGKey(0))

  rhs = jnp.eye(spatial_size, dtype=lhs.dtype).reshape(filter_shape * 2)


In [42]:
for i in params:
    print(i[0].shape, i[1].shape)

(24, 9) (24,)
(8, 144) (8,)
(2, 8) (2,)


## QLayers

In [43]:
dev = qml.device('default.qubit.jax', wires=n_qubits)
qubits =list(range(n_qubits))

@jax.jit
@qml.qnode(dev, interface='jax')
def circuit(inputs, w, b):
    z = jnp.dot(inputs, jnp.transpose(w)) + b
    z = jnp.tanh(z)*jnp.pi/2
    z = jnp.reshape(z, (n_layers, n_qubits, 3, -1))
    for l in range(n_layers):
        for q in qubits:
            qml.Rot(z[l,q,0], z[l,q,1], z[l,q,2], wires= q)
        if (l & 1):
            for q0, q1 in zip(qubits[1::2], qubits[2::2] + [qubits[0]]):
                qml.CZ((q0,q1))
        else:
            for q0, q1 in zip(qubits[0::2], qubits[1::2]):
                qml.CZ((q0,q1))
   
    return [qml.expval(qml.PauliZ(q)) for q in qubits]

In [44]:
def qconv(x, w, b):
    x = jnp.expand_dims(x,axis=0)
        
    dn = jax.lax.conv_dimension_numbers(x.shape, 
                                        (1,1,kernel_size[0],kernel_size[1]), 
                                        ('NHWC', 'IOHW', 'NHWC'))
    x = jax.lax.conv_general_dilated_patches(lhs = x,
                                               filter_shape= kernel_size,
                                               window_strides=strides,
                                               padding=padding,
                                               dimension_numbers=dn 
                                              )
    iters = x.shape[1:3]
    x = jnp.reshape(x, (-1, np.prod(kernel_size)))
    x = circuit(x, w, b)
    x = jnp.reshape(x, iters + (n_qubits,))
    return x

In [45]:
dev = qml.device("default.qubit", wires=n_qubits)
qnode = qml.QNode(circuit, dev)

inputs = np.random.uniform(size = (np.prod(kernel_size),))
w,b  = params[0]
drawer = qml.draw(qnode, expansion_strategy="device")
print(drawer(inputs,w,b))

0: ──Rot─╭●──Rot────╭Z─┤  <Z>
1: ──Rot─╰Z──Rot─╭●─│──┤  <Z>
2: ──Rot─╭●──Rot─╰Z─│──┤  <Z>
3: ──Rot─╰Z──Rot────╰●─┤  <Z>


## Auto-Batching Predictions

In [46]:
from jax.scipy.special import logsumexp

def relu(x):
    return jnp.maximum(0, x)

def predict(params, image):
  # per-example predictions

    qw, qb = params[0]
    activations = qconv(image, qw, qb)
    
    activations = jnp.reshape(activations, (-1))
    
    for w, b in params[1:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)

In [47]:
# This works on single examples
random_flattened_image = random.normal(random.PRNGKey(1), input_dims)
preds = predict(params,  random_flattened_image)
print(preds)

[-0.6991117  -0.68721807]


In [48]:
# Doesn't work with a batch
random_flattened_images = random.normal(random.PRNGKey(1), (2,)+ input_dims)
# try:
#     preds = predict(params, random_flattened_images)
# except TypeError:
#     print('Invalid shapes!')

In [49]:
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None,0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds)

[[-0.75342685 -0.6362955 ]
 [-0.716253   -0.6705632 ]]


## Utility and loss functions

In [50]:
def one_hot(x, k, dtype=jnp.float32):
    """Create a one-hot encoding of x of size k."""
    return jnp.array(x[:, None] == jnp.arange(k), dtype)
  
def accuracy(params, images, targets):
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
    return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
    preds = batched_predict(params, images)
    return jnp.mean(optax.softmax_cross_entropy(logits=preds, labels=targets))

@jit
def update(params, x, y):
    grads = grad(loss)(params, x, y)
    return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]

## Training loop

In [51]:
import time
num_epochs = 10
step_size = jnp.sqrt(step_size)

for epoch in range(num_epochs):
    start_time = time.time()
    with tqdm(tfds.as_numpy(data.train_ds), unit="batch") as tepoch:
        for x, y in tepoch:
            tepoch.set_description(f"Epoch {epoch}")
            params = update(params, x, y)
            
    epoch_time = time.time() - start_time
    train_acc = accuracy(params, x_train, y_train)
    test_acc = accuracy(params, x_val, y_val)
    
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))
    wandb.log({"accuracy": train_acc, "val_accuracy": test_acc})

Epoch 0: 100%|████████████████████████████████████████████████████████████████████| 80/80 [00:52<00:00,  1.51batch/s]
Epoch 1:  18%|███████████▋                                                       | 14/80 [00:00<00:00, 133.80batch/s]

Epoch 0 in 52.87 sec
Training set accuracy 0.8501776456832886
Test set accuracy 0.8405053019523621


Epoch 1: 100%|███████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 142.22batch/s]
Epoch 2:  18%|███████████▋                                                       | 14/80 [00:00<00:00, 134.95batch/s]

Epoch 1 in 0.57 sec
Training set accuracy 0.9289379715919495
Test set accuracy 0.9336754679679871


Epoch 2: 100%|███████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 136.09batch/s]
Epoch 3:  16%|██████████▉                                                        | 13/80 [00:00<00:00, 127.35batch/s]

Epoch 2 in 0.59 sec
Training set accuracy 0.943742573261261
Test set accuracy 0.9455190896987915


Epoch 3: 100%|███████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 137.36batch/s]
Epoch 4:  19%|████████████▌                                                      | 15/80 [00:00<00:00, 149.33batch/s]

Epoch 3 in 0.59 sec
Training set accuracy 0.9441373348236084
Test set accuracy 0.9439399838447571


Epoch 4: 100%|███████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 129.12batch/s]
Epoch 5:  18%|███████████▋                                                       | 14/80 [00:00<00:00, 133.34batch/s]

Epoch 4 in 0.62 sec
Training set accuracy 0.9490722417831421
Test set accuracy 0.9467034935951233


Epoch 5: 100%|███████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 139.38batch/s]
Epoch 6:  18%|███████████▋                                                       | 14/80 [00:00<00:00, 137.10batch/s]

Epoch 5 in 0.58 sec
Training set accuracy 0.9370311498641968
Test set accuracy 0.9328858852386475


Epoch 6: 100%|███████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 136.98batch/s]
Epoch 7:  19%|████████████▌                                                      | 15/80 [00:00<00:00, 139.99batch/s]

Epoch 6 in 0.59 sec
Training set accuracy 0.9394986033439636
Test set accuracy 0.9392024874687195


Epoch 7: 100%|███████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 132.31batch/s]
Epoch 8:  18%|███████████▋                                                       | 14/80 [00:00<00:00, 130.96batch/s]

Epoch 7 in 0.61 sec
Training set accuracy 0.8994275331497192
Test set accuracy 0.8989340662956238


Epoch 8: 100%|███████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 117.16batch/s]
Epoch 9:  18%|███████████▋                                                       | 14/80 [00:00<00:00, 132.96batch/s]

Epoch 8 in 0.69 sec
Training set accuracy 0.9470982551574707
Test set accuracy 0.9451243281364441


Epoch 9: 100%|███████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 123.98batch/s]


Epoch 9 in 0.65 sec
Training set accuracy 0.9476904273033142
Test set accuracy 0.9447295665740967


In [23]:
accuracy(params, x_test, y_test)

DeviceArray(0.9763593, dtype=float32)