In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

from importlib.util import find_spec
if find_spec("qml_hep_lhc") is None:
    import sys
    sys.path.append('..')

In [2]:
from qml_hep_lhc.data import ElectronPhoton, MNIST
from qml_hep_lhc.data.utils import tf_ds_to_numpy
import argparse

import pennylane as qml
from pennylane import numpy as np
from pennylane.optimize import AdamOptimizer, GradientDescentOptimizer
import jax.numpy as jnp
import jax
from flax.training import train_state
from flax import linen as nn
import optax
from absl import logging
from jax import grad, jit, vmap
from jax import random
import tensorflow_datasets as tfds

import matplotlib.pyplot as plt
import time

2022-08-15 11:00:24.646032: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2022-08-15 11:00:24.646066: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [35]:
args = argparse.Namespace()

# Data
# args.center_crop = 0.7
# args.resize = [4,4]
args.standardize = 1
# args.binary_data = [0,1]
# args.percent_samples = 0.01
# args.processed = 1
args.dataset_type = 'med'
args.labels_to_categorical = 1

# Base Model
args.wandb = False
args.batch_size = 64
args.epochs = 30
args.learning_rate = 0.2

# Quantum CNN Parameters
args.n_layers = 2
args.n_qubits = 2
args.template = 'NQubitPQC'

In [36]:
data = ElectronPhoton(args)
data.prepare_data()
data.setup()
print(data)

Standardizing data...
Converting labels to categorical...
Converting labels to categorical...

Dataset :Electron Photon med
╒════════╤═══════════════════╤═══════════════════╤═══════════════════╤═════════════╕
│ Data   │ Train size        │ Val size          │ Test size         │ Dims        │
╞════════╪═══════════════════╪═══════════════════╪═══════════════════╪═════════════╡
│ X      │ (7200, 32, 32, 1) │ (1800, 32, 32, 1) │ (1000, 32, 32, 1) │ (32, 32, 1) │
├────────┼───────────────────┼───────────────────┼───────────────────┼─────────────┤
│ y      │ (7200, 2)         │ (1800, 2)         │ (1000, 2)         │ (2,)        │
╘════════╧═══════════════════╧═══════════════════╧═══════════════════╧═════════════╛

╒══════════════╤═══════╤════════╤════════╤═══════╤══════════════════════════╕
│ Type         │   Min │    Max │   Mean │   Std │ Samples for each class   │
╞══════════════╪═══════╪════════╪════════╪═══════╪══════════════════════════╡
│ Train Images │ -2.88 │  94.86 │      0 │  0.

In [37]:
num_pixels = 32*32
x_train, y_train = tf_ds_to_numpy(data.train_ds)
x_val, y_val = tf_ds_to_numpy(data.val_ds)
x_train = x_train.reshape(-1, num_pixels)
x_val = x_val.reshape(-1, num_pixels)
# y_train = 2.0*y_train-1.0
# y_val = 2.0*y_val-1.0

In [38]:
x_train.shape, y_train.shape

((7200, 1024), (7200, 2))

In [39]:
x_train = jnp.array(x_train)
y_train = jnp.array(y_train)
x_val = jnp.array(x_val)
y_val = jnp.array(y_val)

In [40]:
dev = qml.device('default.qubit.jax', wires=1)

@jax.jit
@qml.qnode(dev, interface='jax')
def circuit(inputs, w, b):
    z = jnp.dot(w, inputs) + b
    qml.Rot(z[0],z[1],z[2], wires=0)
    return qml.expval(qml.PauliZ(0))

In [41]:
vcircuit = jax.vmap(circuit)

## Hyperparameters

In [50]:
# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-1):
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [num_pixels, 128, 64, 16, 3]
step_size = 0.1
num_epochs = 30
batch_size = 128
n_targets = 2
params = init_network_params(layer_sizes, random.PRNGKey(0))

In [51]:
for i in params:
    print(i[0].shape, i[1].shape)

(128, 1024) (128,)
(64, 128) (64,)
(16, 64) (16,)
(3, 16) (3,)


## Auto-Batching Predictions

In [52]:
from jax.scipy.special import logsumexp

def relu(x):
    return jnp.maximum(0, x)

def predict(params, image):
  # per-example predictions
    activations = image
    for w, b in params[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)
    
    qw, qb = params[-1]
    logits = circuit(activations, qw, qb)
    logits = jnp.array([nn.sigmoid(logits)])
    return jnp.concatenate((logits, 1-logits), -1)

In [53]:
# This works on single examples
random_flattened_image = random.normal(random.PRNGKey(1), (num_pixels,))
preds = predict(params,  random_flattened_image)
print(preds)

[0.72555614 0.27444386]


In [54]:
# Doesn't work with a batch
random_flattened_images = random.normal(random.PRNGKey(1), (10, num_pixels))
try:
    preds = predict(params, random_flattened_images)
except TypeError:
    print('Invalid shapes!')

Invalid shapes!


In [55]:
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None,0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds)

[[0.72749335 0.27250665]
 [0.72137743 0.27862257]
 [0.7240357  0.27596432]
 [0.71940887 0.28059113]
 [0.728792   0.271208  ]
 [0.7293075  0.27069253]
 [0.70957965 0.29042035]
 [0.7283114  0.27168858]
 [0.71942383 0.28057617]
 [0.7309015  0.26909852]]


## Utility and loss functions

In [56]:
def one_hot(x, k, dtype=jnp.float32):
    """Create a one-hot encoding of x of size k."""
    return jnp.array(x[:, None] == jnp.arange(k), dtype)
  
def accuracy(params, images, targets):
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
    return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
    preds = batched_predict(params, images)
    return -jnp.mean(preds * targets)

@jit
def update(params, x, y):
    grads = grad(loss)(params, x, y)
    return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]

## Training loop

In [57]:
import time


for epoch in range(num_epochs):
    start_time = time.time()
    for x, y in tfds.as_numpy(data.train_ds):
        x = jnp.reshape(x, (len(x), num_pixels))
        params = update(params, x, y)
    epoch_time = time.time() - start_time

    train_acc = accuracy(params, x_train, y_train)
    test_acc = accuracy(params, x_val, y_val)
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))

Epoch 0 in 1.91 sec
Training set accuracy 0.5015277862548828
Test set accuracy 0.49944445490837097
Epoch 1 in 0.19 sec
Training set accuracy 0.5045833587646484
Test set accuracy 0.49222221970558167
Epoch 2 in 0.20 sec
Training set accuracy 0.5051388740539551
Test set accuracy 0.49722224473953247
Epoch 3 in 0.21 sec
Training set accuracy 0.5090277791023254
Test set accuracy 0.5088889002799988
Epoch 4 in 0.20 sec
Training set accuracy 0.5179166793823242
Test set accuracy 0.5177778005599976
Epoch 5 in 0.25 sec
Training set accuracy 0.53083336353302
Test set accuracy 0.522777795791626
Epoch 6 in 0.19 sec
Training set accuracy 0.5400000214576721
Test set accuracy 0.5244444608688354
Epoch 7 in 0.24 sec
Training set accuracy 0.5498611330986023
Test set accuracy 0.5250000357627869
Epoch 8 in 0.20 sec
Training set accuracy 0.5601388812065125
Test set accuracy 0.523888885974884
Epoch 9 in 0.20 sec
Training set accuracy 0.5712500214576721
Test set accuracy 0.5261111259460449
Epoch 10 in 0.24 sec
