In [1]:
from importlib.util import find_spec
%load_ext autoreload
%autoreload 2

%matplotlib inline

if find_spec("qml_hep_lhc") is None:
    import sys
    sys.path.append('../..')

In [2]:
from qml_hep_lhc.data import ElectronPhoton, MNIST, QuarkGluon
from qml_hep_lhc.data.utils import tf_ds_to_numpy
import argparse
import wandb

import pennylane as qml
import jax.numpy as jnp
import jax
import optax
from jax.nn.initializers import he_uniform, glorot_uniform
from jax import grad, jit, vmap
from jax import random
import flax.linen as nn
import tensorflow_datasets as tfds
from tqdm import tqdm
import numpy as np
import tensorflow as tf


import matplotlib.pyplot as plt
import time

2022-09-25 11:58:34.283774: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2022-09-25 11:58:34.283856: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [3]:
jax.devices()



[CpuDevice(id=0)]

In [70]:
args = argparse.Namespace()

# Data
args.center_crop = 0.7
args.resize = [8,8]
args.standardize = 1
args.binary_data = [0,1]
args.batch_size = 1024
args.validation_split = 0.05
args.labels_to_categorical = 1

# Base Model
args.wandb = False

# Quantum CNN Parameters
args.n_layers = 1
args.n_qubits = 1
args.template = 'NQubitPQCSparse'
args.initializer = 'he_uniform'
args.opt = 'adam'

args.num_qconv_layers = 2
args.qconv_dims = [2, 2]
args.kernel_sizes = [(3, 3), (3, 3)]
args.strides = [(1, 1), (1, 1)]
args.paddings = ["SAME", "SAME"]

args.clayer_sizes = [8, 2]

In [71]:
if args.wandb:
     wandb.init(project='qml-hep-lhc', config = vars(args))

In [72]:
data = MNIST(args)
data.prepare_data()
data.setup()
print(data)

Binarizing data...
Binarizing data...
Center cropping...
Center cropping...
Resizing data...
Resizing data...
Standardizing data...
Converting labels to categorical...
Converting labels to categorical...

Dataset :MNIST
╒════════╤══════════════════╤════════════════╤═════════════════╤═══════════╕
│ Data   │ Train size       │ Val size       │ Test size       │ Dims      │
╞════════╪══════════════════╪════════════════╪═════════════════╪═══════════╡
│ X      │ (12031, 8, 8, 1) │ (634, 8, 8, 1) │ (2115, 8, 8, 1) │ (8, 8, 1) │
├────────┼──────────────────┼────────────────┼─────────────────┼───────────┤
│ y      │ (12031, 2)       │ (634, 2)       │ (2115, 2)       │ (2,)      │
╘════════╧══════════════════╧════════════════╧═════════════════╧═══════════╛

╒══════════════╤═══════╤═══════╤════════╤═══════╤══════════════════════════╕
│ Type         │   Min │   Max │   Mean │   Std │ Samples for each class   │
╞══════════════╪═══════╪═══════╪════════╪═══════╪══════════════════════════╡
│ Train I

## Hyperparameters

In [73]:
input_dims = data.config()['input_dims']
input_dims

(8, 8, 1)

In [74]:
def get_out_shape(in_shape,f, k, s, padding):
    in_shape = (1,) + in_shape
    a = np.random.uniform(size = (in_shape))
    dn = jax.lax.conv_dimension_numbers(a.shape, (1,1,k[0],k[1]), ('NHWC', 'IOHW', 'NHWC'))
    out = jax.lax.conv_general_dilated_patches(lhs = a,
                                           filter_shape= k,
                                           window_strides=s,
                                           padding=padding,
                                           dimension_numbers=dn 
                                    )
    return out.shape[1:3]+(f,)

In [75]:
# Get qlayer sizes
def get_qlayer_sizes(template, n_l, n_q, k_size):
    if template == 'NQubitPQCSparse':
        return {
            'w': (n_l, n_q,3,np.prod(k_size)),
            'b': (n_l,n_q,3,1)
        }
    elif template == 'NQubitPQC':
        assert np.prod(k_size)%3 == 0
        return {
            'w': (n_l,n_q,np.prod(k_size)),
            'b': (n_l,n_q,np.prod(k_size))
        }

initializer = he_uniform()

In [76]:
def random_qlayer_params(size, key, filters, n_channels, scale=1e-1):
    w =  initializer(key, size)
    tile_shape = (filters,n_channels,) + (1,)*len(size)
    w = jnp.tile(w, tile_shape)
    return w

def init_qnetwork_params(in_shape, filters, kernel_size, strides, padding, template, n_l, n_q, key):
    n_channels = in_shape[-1]
    sizes = get_qlayer_sizes(template, n_l,n_q, kernel_size)
    keys = random.split(key, len(sizes))
    return [random_qlayer_params(size, key, filters, n_channels) for size, key in zip(sizes.values(), keys)]

In [77]:
# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_clayer_params(m, n, key, scale=1e-1):
    w_key, b_key = random.split(key)
    return initializer(w_key, (n,m)), random.normal(b_key, (n,))
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [random_clayer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

In [78]:
num_qconv_layers = args.num_qconv_layers
qconv_dims = args.qconv_dims
kernel_sizes = args.kernel_sizes
strides = args.strides
paddings = args.paddings
clayer_sizes = args.clayer_sizes

template = args.template
n_layers = args.n_layers
n_qubits = args.n_qubits

In [79]:
in_shape = input_dims
params = []
for l in range(num_qconv_layers):
    qconv_params = init_qnetwork_params(in_shape, 
                                         qconv_dims[l], 
                                         kernel_sizes[l], 
                                         strides[l], 
                                         paddings[l],
                                         template, 
                                         n_layers,
                                         n_qubits,
                                         random.PRNGKey(l))
    params += [qconv_params]
    in_shape = get_out_shape(in_shape,qconv_dims[l],kernel_sizes[l],strides[l],paddings[l])
    print(in_shape)
    

num_pixels = np.prod(in_shape)
clayer_sizes = [num_pixels] + clayer_sizes

params += init_network_params(clayer_sizes, random.PRNGKey(2))

  rhs = jnp.eye(spatial_size, dtype=lhs.dtype).reshape(filter_shape * 2)


(8, 8, 2)
(8, 8, 2)


In [80]:
for i in params:
    for j in i:
        print(j.shape, end = ' ')
    print()

(2, 1, 1, 1, 3, 9) (2, 1, 1, 1, 3, 1) 
(2, 2, 1, 1, 3, 9) (2, 2, 1, 1, 3, 1) 
(8, 128) (8,) 
(2, 8) (2,) 


## QLayers

In [81]:
dev = qml.device('default.qubit.jax', wires=n_qubits)
qubits =list(range(n_qubits))

@jax.jit
@qml.qnode(dev, interface='jax')
def NQubitPQCSparse(inputs, w, b):

    z = jnp.dot(w, jnp.transpose(inputs))+ b

    for q in qubits:
        qml.Hadamard(wires=q)
    
    for l in range(n_layers):
        for q in qubits:
            qml.Rot(z[l,q,0], z[l,q,1], z[l,q,2], wires= q)
        if (l & 1):
            for q0, q1 in zip(qubits[1::2], qubits[2::2] + [qubits[0]]):
                qml.CZ((q0,q1))
        else:
            for q0, q1 in zip(qubits[0::2], qubits[1::2]):
                qml.CZ((q0,q1))
   
    return qml.expval(qml.PauliZ(qubits[-1]))

In [82]:
dev = qml.device('default.qubit.jax', wires=n_qubits)
qubits =list(range(n_qubits))

@jax.jit
@qml.qnode(dev, interface='jax')
def NQubitPQC(inputs, w, b):
    steps = inputs.shape[-1]//3
    for q in qubits:
        qml.Hadamard(wires=q)
    
    for l in range(n_layers):
        for q in qubits:
            for i in range(steps):
                z = jnp.transpose(jnp.multiply(inputs[:,3*i:3*i+3],w[l,q,3*i:3*i+3]) + b[l,q,3*i:3*i+3])
                qml.RZ(z[0], wires=q)
                qml.RY(z[1], wires=q)
                qml.RZ(z[2], wires=q)
                
        if (l & 1):
            for q0, q1 in zip(qubits[1::2], qubits[2::2] + [qubits[0]]):
                qml.CZ((q0,q1))
        else:
            for q0, q1 in zip(qubits[0::2], qubits[1::2]):
                qml.CZ((q0,q1))

    return qml.expval(qml.PauliZ(qubits[-1]))

In [83]:
def get_node(template):
    if template == 'NQubitPQC':
        return NQubitPQC
    elif template == 'NQubitPQCSparse':
        return NQubitPQCSparse

In [84]:
def qconv_cop(x, w,b):
    end_dim = x.shape[-1]
    iters = x.shape[1:3]
    x = jnp.reshape(x , (-1,)+ (end_dim,))
    x = get_node(template)(x, w, b)
    x = jnp.reshape(x, (-1,) + iters)
    return x

batched_qconv_cop = vmap(qconv_cop, in_axes=(3, 0, 0))

def qconv_fop(x, w, b):
    x = batched_qconv_cop(x,w,b)
    x = jnp.sum(x, axis= 0)
    return x

batched_qconv_fop = vmap(qconv_fop, in_axes=(None,0,0))

def qconv(x, params, filters, kernel_size, stride, padding):
    n_channels = x.shape[-1]
    x = jnp.expand_dims(x,axis=0)
    dn = jax.lax.conv_dimension_numbers(x.shape, 
                                        (1,1,kernel_size[0],kernel_size[1]), 
                                        ('NHWC', 'IOHW', 'NHWC'))
    x = jax.lax.conv_general_dilated_patches(lhs = x,
                                               filter_shape= kernel_size,
                                               window_strides=stride,
                                               padding=padding,
                                               dimension_numbers=dn 
                                              )
    iters = x.shape[1:3]
    x = jnp.reshape(x, ((-1,) + iters + (n_channels,) + (np.prod(kernel_size),)))
    x = batched_qconv_fop(x, params[0], params[1])
    x = jnp.reshape(x, iters + (filters,))
    return x

In [85]:
random_flattened_image = random.normal(random.PRNGKey(1), input_dims)
random_flattened_image = jnp.floor(random_flattened_image*10)
random_flattened_image.shape

(8, 8, 1)

In [86]:
out = random_flattened_image
for l in range(num_qconv_layers):
    out = qconv(out, 
                params[l],
                qconv_dims[l], 
                kernel_sizes[l], 
                strides[l], 
                paddings[l])
    print(out.shape)

(8, 8, 2)
(8, 8, 2)


In [87]:
dev = qml.device("default.qubit", wires=n_qubits)
qnode = qml.QNode(get_node(template), dev)

inputs = np.random.uniform(size = (10,np.prod(kernel_sizes[0])))
weights = [params[0][0][0][0], params[0][1][0][0]]
drawer = qml.draw(qnode, expansion_strategy="device")
print(drawer(inputs,*weights))

0: ──H──Rot─┤  <Z>


## Auto-Batching Predictions

In [88]:
from jax.scipy.special import logsumexp

def relu(x):
    return jnp.maximum(0, x)

def forward(params, image):
    # per-example predictions
    
    activations = image
    for l in range(num_qconv_layers):
        activations = qconv(activations, params[l], qconv_dims[l], kernel_sizes[l], strides[l], paddings[l])
        
    activations += image
    activations = relu(activations)
        
    activations = jnp.reshape(activations, (-1))
    
    for w, b in params[num_qconv_layers:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    
    return logits - logsumexp(logits)

In [89]:
# This works on single examples
preds = forward(params,  random_flattened_image)
print(preds)

[-50.780647   0.      ]


In [90]:
# Doesn't work with a batch
random_flattened_images = random.normal(random.PRNGKey(1), (2,)+ input_dims)
random_flattened_images = jnp.floor(random_flattened_images*10)
try:
    preds = forward(params, random_flattened_images)
except TypeError:
    print('Invalid shapes!')

Invalid shapes!


In [91]:
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_forward = vmap(forward, in_axes=(None,0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_forward(params, random_flattened_images)
print(batched_preds)

[[   0.      -130.99347]
 [   0.       -80.78039]]


## Utility and loss functions

In [92]:
from sklearn.metrics import roc_auc_score

def accuracy(y_true, y_pred):
    target_class = jnp.argmax(y_true, axis=1)
    predicted_class = jnp.argmax(y_pred, axis=1)
    return jnp.mean(predicted_class == target_class)

def loss_fn(params, images, targets):
    preds = batched_forward(params, images)
    loss_value = -jnp.mean(preds * targets)
    return loss_value, preds

@jit
def update(opt_state, params, x, y):
    _ , grads = jax.value_and_grad(loss_fn, has_aux=True)(params, x, y)

    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    return params, opt_state 

@jit
def step(params,x,y):
    loss_value, preds = loss_fn(params, x, y)
    acc = accuracy(y, preds)
    return loss_value, acc

def evaluate(params, ds):
    losses = []
    accs = []
    with tqdm(tfds.as_numpy(ds), unit="batch") as tepoch:
        for x, y in tepoch:
            tepoch.set_description("Validation")
            loss_value, acc = step(params, x, y)
            losses.append(loss_value)
            accs.append(acc)
       
    return jnp.mean(np.array(losses)), jnp.mean(np.array(accs))

def predict(params, ds):
    preds = []
    y_true = []
    with tqdm(tfds.as_numpy(ds), unit="batch") as tepoch:
        for x, y in tepoch:
            preds += list(batched_forward(params, x))
            y_true += list(y)
    
    return np.array(preds), np.array(y_true)

## Training loop

In [93]:
import time

epochs = 10
N = 1
lr = 5e-3

epoch_times = []

for i in range(N):
    
    print('Learning Rate:', lr)

    optimizer = optax.adam(learning_rate=lr)
    opt_state = optimizer.init(params)
    
    for epoch in range(epochs):
        start_time = time.time()

        with tqdm(tfds.as_numpy(data.train_ds), unit="batch") as tepoch:
            for x, y in tepoch:
                tepoch.set_description(f"Epoch {epoch}")
                params, opt_state = update(opt_state, params, x, y)

        epoch_time = time.time() - start_time
        epoch_times.append(epoch_time)

        loss, acc = evaluate(params, data.train_ds)
        val_loss, val_acc = evaluate(params, data.val_ds)

        print('loss: {} - acc: {}'.format(loss, acc))
        print('val_loss: {} - val_acc: {}'.format(val_loss, val_acc))
        print('time: {}'.format(epoch_time))

        if args.wandb:
            wandb.log({"accuracy": acc, 
                       "val_accuracy": val_acc, 
                       'loss':loss, 
                       'lr':lr,
                       'val_loss':val_loss})

    lr = lr*np.sqrt(0.1)

Learning Rate: 0.005


Epoch 0: 100%|████████████████████████████████████████████████████████████████████| 12/12 [00:40<00:00,  3.41s/batch]
Validation: 100%|█████████████████████████████████████████████████████████████████| 12/12 [00:24<00:00,  2.01s/batch]
Validation: 100%|███████████████████████████████████████████████████████████████████| 1/1 [00:11<00:00, 11.18s/batch]


loss: 0.29725784063339233 - acc: 0.8465805053710938
val_loss: 0.2740291953086853 - val_acc: 0.8943217992782593
time: 40.946343421936035


Epoch 1: 100%|████████████████████████████████████████████████████████████████████| 12/12 [00:01<00:00,  6.65batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 20.31batch/s]
Validation: 100%|███████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 22.01batch/s]


loss: 0.1250562071800232 - acc: 0.942348837852478
val_loss: 0.149129718542099 - val_acc: 0.9511041045188904
time: 1.8112201690673828


Epoch 2: 100%|████████████████████████████████████████████████████████████████████| 12/12 [00:01<00:00,  7.04batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 13.05batch/s]
Validation: 100%|███████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 19.59batch/s]


loss: 0.046767085790634155 - acc: 0.9776724576950073
val_loss: 0.0577176995575428 - val_acc: 0.9858044385910034
time: 1.718628168106079


Epoch 3: 100%|████████████████████████████████████████████████████████████████████| 12/12 [00:02<00:00,  4.69batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 17.84batch/s]
Validation: 100%|███████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 22.98batch/s]


loss: 0.022341733798384666 - acc: 0.9884427785873413
val_loss: 0.04186905920505524 - val_acc: 0.9858044385910034
time: 2.5704989433288574


Epoch 4: 100%|████████████████████████████████████████████████████████████████████| 12/12 [00:01<00:00,  6.62batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 15.47batch/s]
Validation: 100%|███████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 14.07batch/s]


loss: 0.012991329655051231 - acc: 0.9934887290000916
val_loss: 0.03464209660887718 - val_acc: 0.9952681660652161
time: 1.8176608085632324


Epoch 5: 100%|████████████████████████████████████████████████████████████████████| 12/12 [00:02<00:00,  5.90batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 22.48batch/s]
Validation: 100%|███████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 28.87batch/s]


loss: 0.009302498772740364 - acc: 0.9952249526977539
val_loss: 0.036226365715265274 - val_acc: 0.9952681660652161
time: 2.0422921180725098


Epoch 6: 100%|████████████████████████████████████████████████████████████████████| 12/12 [00:01<00:00, 10.15batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 24.29batch/s]
Validation: 100%|███████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 30.86batch/s]


loss: 0.007727194577455521 - acc: 0.9961475133895874
val_loss: 0.03410317003726959 - val_acc: 0.9952681660652161
time: 1.1867685317993164


Epoch 7: 100%|████████████████████████████████████████████████████████████████████| 12/12 [00:01<00:00,  9.03batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 23.55batch/s]
Validation: 100%|███████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 26.99batch/s]


loss: 0.006787057500332594 - acc: 0.9965001940727234
val_loss: 0.031796667724847794 - val_acc: 0.9952681660652161
time: 1.3335788249969482


Epoch 8: 100%|████████████████████████████████████████████████████████████████████| 12/12 [00:01<00:00,  9.60batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 25.23batch/s]
Validation: 100%|███████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 25.70batch/s]


loss: 0.006115609314292669 - acc: 0.9965815544128418
val_loss: 0.030482735484838486 - val_acc: 0.9952681660652161
time: 1.2561817169189453


Epoch 9: 100%|████████████████████████████████████████████████████████████████████| 12/12 [00:01<00:00,  8.74batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 25.61batch/s]
Validation: 100%|███████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 27.52batch/s]

loss: 0.005556419491767883 - acc: 0.9966902732849121
val_loss: 0.029745331034064293 - val_acc: 0.9968454241752625
time: 1.377626657485962





In [95]:
test_loss, test_acc = evaluate(params, data.test_ds)
test_loss, test_acc

Validation: 100%|███████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 31.70batch/s]


(DeviceArray(0.00501728, dtype=float32),
 DeviceArray(0.99837244, dtype=float32))

In [96]:
from sklearn.metrics import roc_auc_score

out,y_train = predict(params, data.train_ds)
train_auc = roc_auc_score(y_train, out)
train_auc

100%|█████████████████████████████████████████████████████████████████████████████| 12/12 [00:22<00:00,  1.87s/batch]


0.99989940204576

In [97]:
from sklearn.metrics import roc_auc_score

out,y_test = predict(params, data.test_ds)
test_auc = roc_auc_score(y_test, out)
test_auc

100%|███████████████████████████████████████████████████████████████████████████████| 3/3 [00:10<00:00,  3.40s/batch]


0.9998678414096915

In [98]:
if args.wandb:
    wandb.run.summary['test_loss'] = test_loss
    wandb.run.summary['test_acc'] = test_acc
    wandb.run.summary['test_auc'] = test_auc
    wandb.run.summary['train_auc'] = train_auc
    wandb.run.summary['avg_epoch_time'] = np.mean(np.array(epoch_times))
    y = y_test.argmax(axis=1)
    preds = out.argmax(axis=1)
    probs = out
    classes = data.mapping

    roc_curve = wandb.sklearn.plot_roc(y, probs, classes)
    confusion_matrix = wandb.sklearn.plot_confusion_matrix(y, preds, classes)

    wandb.log({"roc_curve": roc_curve})
    wandb.log({"confusion_matrix": confusion_matrix})
    wandb.finish()