In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

from importlib.util import find_spec
if find_spec("qml_hep_lhc") is None:
    import sys
    sys.path.append('..')

In [2]:
from qml_hep_lhc.data import ElectronPhoton, MNIST, QuarkGluon
from qml_hep_lhc.data.utils import tf_ds_to_numpy
from qml_hep_lhc.data.preprocessor import DataPreprocessor
from sklearn.utils import shuffle
import argparse
import wandb

import pennylane as qml
from pennylane import numpy as np
from pennylane.optimize import AdamOptimizer, GradientDescentOptimizer
import jax.numpy as jnp
import jax
from flax.training import train_state
from flax import linen as nn
import optax
from absl import logging
from jax import grad, jit, vmap
from jax import random
import tensorflow_datasets as tfds
from tqdm import tqdm
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt
import time

2022-08-16 09:09:09.421609: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2022-08-16 09:09:09.421696: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [123]:
args = argparse.Namespace()

# Data
args.center_crop = 0.7
args.resize = [8,8]
args.standardize = 1
# args.power_transform = 1
args.binary_data = [0,1]
# args.percent_samples = 0.01
# args.processed = 1
# args.dataset_type = 'med'
args.labels_to_categorical = 1
args.batch_size = 128

# Base Model
args.wandb = False
# args.epochs = 30
# args.learning_rate = 0.2

# Quantum CNN Parameters
# args.n_layers = 2
# args.n_qubits = 2
# args.template = 'NQubitPQC'

In [144]:
data = MNIST(args)
data.prepare_data()
data.setup()
print(data)

Binarizing data...
Binarizing data...
Center cropping...
Center cropping...
Resizing data...
Resizing data...
Standardizing data...
Converting labels to categorical...
Converting labels to categorical...

Dataset :MNIST
╒════════╤══════════════════╤═════════════════╤═════════════════╤═══════════╕
│ Data   │ Train size       │ Val size        │ Test size       │ Dims      │
╞════════╪══════════════════╪═════════════════╪═════════════════╪═══════════╡
│ X      │ (10132, 8, 8, 1) │ (2533, 8, 8, 1) │ (2115, 8, 8, 1) │ (8, 8, 1) │
├────────┼──────────────────┼─────────────────┼─────────────────┼───────────┤
│ y      │ (10132, 2)       │ (2533, 2)       │ (2115, 2)       │ (2,)      │
╘════════╧══════════════════╧═════════════════╧═════════════════╧═══════════╛

╒══════════════╤═══════╤═══════╤════════╤═══════╤══════════════════════════╕
│ Type         │   Min │   Max │   Mean │   Std │ Samples for each class   │
╞══════════════╪═══════╪═══════╪════════╪═══════╪══════════════════════════╡
│ 

In [145]:
if args.wandb:
     wandb.init(project='qml-hep-lhc')

## Hyperparameters

In [146]:
input_dims = data.config()['input_dims']

In [147]:
def get_out_shape(in_shape, k, s, padding):
    in_shape = (1,) + in_shape
    a = np.random.uniform(size = (in_shape))
    dn = jax.lax.conv_dimension_numbers(a.shape, (1,1,k[0],k[1]), ('NHWC', 'IOHW', 'NHWC'))
    out = jax.lax.conv_general_dilated_patches(lhs = a,
                                           filter_shape= k,
                                           window_strides=s,
                                           padding=padding,
                                           dimension_numbers=dn 
                                    )
    return out.shape

In [168]:
# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-1):
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]
 

kernel_size = (3,3)
strides = (2,2)
padding = "VALID"
clayer_sizes = [8,2]

step_size = 0.1 
num_epochs = 30

n_layers = 1
n_qubits = 1


conv_out_shape = get_out_shape(input_dims, kernel_size, strides, padding)
num_pixels = np.prod(conv_out_shape[:-1])*n_qubits
qlayer_sizes = [np.prod(kernel_size), n_layers*n_qubits*3]
clayer_sizes = [num_pixels] + clayer_sizes

params = init_network_params(qlayer_sizes, random.PRNGKey(0))
params += init_network_params(clayer_sizes, random.PRNGKey(0))

  rhs = jnp.eye(spatial_size, dtype=lhs.dtype).reshape(filter_shape * 2)


In [169]:
for i in params:
    print(i[0].shape, i[1].shape)

(3, 9) (3,)
(8, 9) (8,)
(2, 8) (2,)


## QLayers

In [170]:
dev = qml.device('default.qubit.jax', wires=n_qubits)
qubits =list(range(n_qubits))

@jax.jit
@qml.qnode(dev, interface='jax')
def circuit(inputs, w, b):
    z = jnp.dot(inputs, jnp.transpose(w)) + b
    z = jnp.tanh(z)*jnp.pi/2
    z = jnp.reshape(z, (n_layers, n_qubits, 3, -1))
    for l in range(n_layers):
        for q in qubits:
            qml.Rot(z[l,q,0], z[l,q,1], z[l,q,2], wires= q)
        if (l & 1):
            for q0, q1 in zip(qubits[1::2], qubits[2::2] + [qubits[0]]):
                qml.CZ((q0,q1))
        else:
            for q0, q1 in zip(qubits[0::2], qubits[1::2]):
                qml.CZ((q0,q1))
   
    return [qml.expval(qml.PauliZ(q)) for q in qubits]

In [171]:
def qconv(x, w, b):
    x = jnp.expand_dims(x,axis=0)
        
    dn = jax.lax.conv_dimension_numbers(x.shape, 
                                        (1,1,kernel_size[0],kernel_size[1]), 
                                        ('NHWC', 'IOHW', 'NHWC'))
    x = jax.lax.conv_general_dilated_patches(lhs = x,
                                               filter_shape= kernel_size,
                                               window_strides=strides,
                                               padding=padding,
                                               dimension_numbers=dn 
                                              )
    iters = x.shape[1:3]
    x = jnp.reshape(x, (-1, np.prod(kernel_size)))
    x = circuit(x, w, b)
    x = jnp.reshape(x, iters + (n_qubits,))
    return x

In [172]:
dev = qml.device("default.qubit", wires=n_qubits)
qnode = qml.QNode(circuit, dev)

inputs = np.random.uniform(size = (np.prod(kernel_size),))
w,b  = params[0]
drawer = qml.draw(qnode, expansion_strategy="device")
print(drawer(inputs,w,b))

0: ──Rot─┤  <Z>


## Auto-Batching Predictions

In [173]:
from jax.scipy.special import logsumexp

def relu(x):
    return jnp.maximum(0, x)

def forward(params, image):
  # per-example predictions

    qw, qb = params[0]
    activations = qconv(image, qw, qb)
    
    activations = jnp.reshape(activations, (-1))
    
    for w, b in params[1:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)

In [174]:
# This works on single examples
random_flattened_image = random.normal(random.PRNGKey(1), input_dims)
preds = forward(params,  random_flattened_image)
print(preds)

[-0.72416043 -0.6630668 ]


In [175]:
# Doesn't work with a batch
random_flattened_images = random.normal(random.PRNGKey(1), (2,)+ input_dims)
# try:
#     preds = predict(params, random_flattened_images)
# except TypeError:
#     print('Invalid shapes!')

In [176]:
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_forward = vmap(forward, in_axes=(None,0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_forward(params, random_flattened_images)
print(batched_preds)

[[-0.7249526  -0.6623222 ]
 [-0.7243084  -0.66292775]]


## Utility and loss functions

In [177]:
def accuracy(y_true, y_pred):
    target_class = jnp.argmax(y_true, axis=1)
    predicted_class = jnp.argmax(y_pred, axis=1)
    return jnp.mean(predicted_class == target_class)


def loss_fn(params, images, targets):
    preds = batched_forward(params, images)
    loss_value = jnp.mean(optax.softmax_cross_entropy(logits=preds, labels=targets))
    return loss_value, preds

@jit
def update(opt_state, params, x, y):
    (loss_value, preds), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state


def step(params,x,y):
    loss_value, preds = loss_fn(params, x, y)
    acc = accuracy(y, preds)
    return loss, acc

def evaluate(params, ds):
    res = jnp.array([step(params, x,y)  for x,y in tfds.as_numpy(ds)])
    avg_loss = jnp.mean(res[:,0])
    avg_acc = jnp.mean(res[:,1])
    return avg_loss, avg_acc

## Training loop

In [178]:
# Defining an optimizer in Jax
optimizer = optax.adam(step_size)
opt_state = optimizer.init(params)

In [181]:
import time
num_epochs = 10

for epoch in range(num_epochs):
    start_time = time.time()
    with tqdm(tfds.as_numpy(data.train_ds), unit="batch") as tepoch:
        for x, y in tepoch:
            tepoch.set_description(f"Epoch {epoch}")
            params, opt_state = update(opt_state, params, x, y)
        
        epoch_time = time.time() - start_time
        
        train_loss, train_acc = evaluate(params, data.train_ds)
        val_loss, val_acc = evaluate(params, data.val_ds)
        print('loss: {} - acc: {} - val_loss: {} - val_acc: {}'.format(train_loss,train_acc, val_loss, val_acc))
#     print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
#     print("Training set accuracy {}".format(train_acc))
#     print("Test set accuracy {}".format(test_acc))
    
    if args.wandb:
        wandb.log({"accuracy": train_acc, "val_accuracy": test_acc, 'loss':train_loss, 'val_loss':val_loss})

step_size = jnp.sqrt(step_size)

Epoch 0: 100%|███████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 331.96batch/s]
Epoch 1:  42%|████████████████████████████▍                                      | 34/80 [00:00<00:00, 336.50batch/s]

loss: 0.6695100665092468 - acc: 0.986328125 - val_loss: 0.6695100665092468 - val_acc: 0.9917969107627869


Epoch 1: 100%|███████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 255.60batch/s]
Epoch 2:  41%|███████████████████████████▋                                       | 33/80 [00:00<00:00, 329.06batch/s]

loss: 0.6695100665092468 - acc: 0.986035168170929 - val_loss: 0.6695100665092468 - val_acc: 0.9884630441665649


Epoch 2: 100%|███████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 316.98batch/s]
Epoch 3:  44%|█████████████████████████████▎                                     | 35/80 [00:00<00:00, 347.58batch/s]

loss: 0.6695100665092468 - acc: 0.986621081829071 - val_loss: 0.6695100665092468 - val_acc: 0.9913018345832825


Epoch 3: 100%|███████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 342.30batch/s]
Epoch 4:  40%|██████████████████████████▊                                        | 32/80 [00:00<00:00, 316.69batch/s]

loss: 0.6695100665092468 - acc: 0.9862304925918579 - val_loss: 0.6695100665092468 - val_acc: 0.9909111857414246


Epoch 4: 100%|███████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 308.76batch/s]
Epoch 5:  42%|████████████████████████████▍                                      | 34/80 [00:00<00:00, 338.03batch/s]

loss: 0.6695100665092468 - acc: 0.9849609732627869 - val_loss: 0.6695100665092468 - val_acc: 0.9862236976623535


Epoch 5: 100%|███████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 336.22batch/s]
Epoch 6:  44%|█████████████████████████████▎                                     | 35/80 [00:00<00:00, 342.34batch/s]

loss: 0.6695100665092468 - acc: 0.98583984375 - val_loss: 0.6695100665092468 - val_acc: 0.9901299476623535


Epoch 6: 100%|███████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 306.03batch/s]
Epoch 7:  35%|███████████████████████▍                                           | 28/80 [00:00<00:00, 279.86batch/s]

loss: 0.6695100665092468 - acc: 0.9852734804153442 - val_loss: 0.6695100665092468 - val_acc: 0.991406261920929


Epoch 7: 100%|███████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 300.17batch/s]
Epoch 8:  42%|████████████████████████████▍                                      | 34/80 [00:00<00:00, 336.58batch/s]

loss: 0.6695100665092468 - acc: 0.9869140982627869 - val_loss: 0.6695100665092468 - val_acc: 0.990234375


Epoch 8: 100%|███████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 335.95batch/s]
Epoch 9:  34%|██████████████████████▌                                            | 27/80 [00:00<00:00, 267.32batch/s]

loss: 0.6695100665092468 - acc: 0.9857031106948853 - val_loss: 0.6695100665092468 - val_acc: 0.98828125


Epoch 9: 100%|███████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 311.20batch/s]


loss: 0.6695100665092468 - acc: 0.986328125 - val_loss: 0.6695100665092468 - val_acc: 0.9881768226623535


In [54]:
accuracy(params, x_test, y_test)

DeviceArray(0.9607565, dtype=float32)

In [55]:
if args.wandb:
    wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▆██████▇█████▁███▅▅█
val_accuracy,▆██████▇█████▁███▅▅█

0,1
accuracy,0.9544
val_accuracy,0.9546
