In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

from importlib.util import find_spec
if find_spec("qml_hep_lhc") is None:
    import sys
    sys.path.append('..')

In [477]:
from qml_hep_lhc.data import ElectronPhoton, MNIST, QuarkGluon
from qml_hep_lhc.data.utils import tf_ds_to_numpy
import argparse
import wandb

import pennylane as qml
import jax.numpy as jnp
import jax
import optax
from jax.nn.initializers import he_uniform
from jax import grad, jit, vmap
from jax import random
import tensorflow_datasets as tfds
from tqdm import tqdm
import numpy as np
import tensorflow as tf

# Added to silence some warnings.
from jax.config import config
config.update("jax_enable_x64", True)

import matplotlib.pyplot as plt
import time

In [3]:
jax.devices()



[CpuDevice(id=0)]

In [746]:
args = argparse.Namespace()

# Data
args.center_crop = 0.2
# args.resize = [8,8]
args.standardize = 1
# args.power_transform = 1
args.binary_data = [3,6]
# args.percent_samples = 0.01
# args.processed = 1
# args.dataset_type = '2'
args.labels_to_categorical = 1
args.batch_size = 128
args.validation_split = 0.1

# Base Model
args.wandb = True
args.epochs = 5
args.learning_rate = 0.008

# Quantum CNN Parameters
args.n_layers = 1
args.n_qubits = 1
args.template = 'NQubitPQCSparse'
args.initializer = 'he_uniform'

args.kernel_size = (3,3)
args.strides = (1,1)
args.padding = "SAME"

args.clayer_sizes = [8, 2]

In [747]:
if args.wandb:
     wandb.init(project='qml-hep-lhc', config = vars(args))

2022-08-18 18:00:01.592936: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory


In [748]:
data = MNIST(args)
data.prepare_data()
data.setup()
print(data)

Binarizing data...
Binarizing data...
Center cropping...
Center cropping...
Resizing data...
Resizing data...
Standardizing data...
Converting labels to categorical...
Converting labels to categorical...

Dataset :MNIST
╒════════╤══════════════════╤═════════════════╤═════════════════╤═══════════╕
│ Data   │ Train size       │ Val size        │ Test size       │ Dims      │
╞════════╪══════════════════╪═════════════════╪═════════════════╪═══════════╡
│ X      │ (10844, 8, 8, 1) │ (1205, 8, 8, 1) │ (1968, 8, 8, 1) │ (8, 8, 1) │
├────────┼──────────────────┼─────────────────┼─────────────────┼───────────┤
│ y      │ (10844, 2)       │ (1205, 2)       │ (1968, 2)       │ (2,)      │
╘════════╧══════════════════╧═════════════════╧═════════════════╧═══════════╛

╒══════════════╤═══════╤═══════╤════════╤═══════╤══════════════════════════╕
│ Type         │   Min │   Max │   Mean │   Std │ Samples for each class   │
╞══════════════╪═══════╪═══════╪════════╪═══════╪══════════════════════════╡
│ 

## Hyperparameters

In [749]:
input_dims = data.config()['input_dims']

In [750]:
def get_out_shape(in_shape, k, s, padding):
    in_shape = (1,) + in_shape
    a = np.random.uniform(size = (in_shape))
    dn = jax.lax.conv_dimension_numbers(a.shape, (1,1,k[0],k[1]), ('NHWC', 'IOHW', 'NHWC'))
    out = jax.lax.conv_general_dilated_patches(lhs = a,
                                           filter_shape= k,
                                           window_strides=s,
                                           padding=padding,
                                           dimension_numbers=dn 
                                    )
    return out.shape

In [751]:
initializer = he_uniform()

# Get qlayer sizes
def get_qlayer_sizes(template, n_l, n_q, k_size):
    if template == 'NQubitPQCSparse':
        return {
            'w': (n_l, n_q,3,np.prod(k_size)),
            'b': (n_l,n_q,3,1)
        }
    elif template == 'SimpleDRC':
        return {
            'w': (n_l+1,n_q,3),
            's': (n_l,n_q),
            'b': (n_l,n_q)
        }
    elif template == 'NQubitPQC':
        assert np.prod(k_size)%3 == 0
        return {
            'w': (n_l,n_q,np.prod(k_size)),
            'b': (n_l,n_q,np.prod(k_size))
        }

def random_qlayer_params(size, key, scale=1e-1):
    return initializer(key, size)
    return scale * random.normal(key, size)

def init_qnetwork_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [[random_qlayer_params(size, key) for size, key in zip(sizes.values(), keys)]]
 

# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_clayer_params(m, n, key, scale=1e-1):
    w_key, b_key = random.split(key)
    return initializer(w_key, (n,m)), random.normal(b_key, (n,))
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [random_clayer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

kernel_size = args.kernel_size
strides = args.strides
padding = args.padding
clayer_sizes = args.clayer_sizes

template = args.template
n_layers = args.n_layers
n_qubits = args.n_qubits


conv_out_shape = get_out_shape(input_dims, kernel_size, strides, padding)
num_pixels = np.prod(conv_out_shape[:-1])*1
qlayer_sizes = get_qlayer_sizes(template, n_layers, n_qubits, kernel_size)
clayer_sizes = [num_pixels] + clayer_sizes

params = []
params += init_qnetwork_params(qlayer_sizes, random.PRNGKey(0))
# params += init_qnetwork_params(qlayer_sizes, random.PRNGKey(1))
params += init_network_params(clayer_sizes, random.PRNGKey(2))

In [752]:
for i in params:
    for j in i:
        print(j.shape, end = ' ')
    print()

(1, 1, 3, 9) (1, 1, 3, 1) 
(8, 64) (8,) 
(2, 8) (2,) 


## QLayers

In [753]:
dev = qml.device('default.qubit.jax', wires=n_qubits)
qubits =list(range(n_qubits))

@jax.jit
@qml.qnode(dev, interface='jax')
def NQubitPQCSparse(inputs, w, b):
    z = jnp.dot(w, jnp.transpose(inputs))+ b

    for q in qubits:
        qml.Hadamard(wires=q)
    
    for l in range(n_layers):
        for q in qubits:
            qml.Rot(z[l,q,0], z[l,q,1], z[l,q,2], wires= q)
        if (l & 1):
            for q0, q1 in zip(qubits[1::2], qubits[2::2] + [qubits[0]]):
                qml.CZ((q0,q1))
        else:
            for q0, q1 in zip(qubits[0::2], qubits[1::2]):
                qml.CZ((q0,q1))
   
    return qml.expval(qml.PauliZ(qubits[-1]))
    return [qml.expval(qml.PauliZ(q)) for q in qubits]

In [754]:
dev = qml.device('default.qubit.jax', wires=n_qubits)
qubits =list(range(n_qubits))

@jax.jit
@qml.qnode(dev, interface='jax')
def NQubitPQC(inputs, w, b):
    steps = inputs.shape[-1]//3
    for q in qubits:
        qml.Hadamard(wires=q)
    
    for l in range(n_layers):
        for q in qubits:
            for i in range(steps):
                z = jnp.transpose(jnp.multiply(inputs[:,3*i:3*i+3],w[l,q,3*i:3*i+3]) + b[l,q,3*i:3*i+3])
                qml.RX(z[0], wires=q)
                qml.RY(z[1], wires=q)
                qml.RZ(z[2], wires=q)
                
        if (l & 1):
            for q0, q1 in zip(qubits[1::2], qubits[2::2] + [qubits[0]]):
                qml.CZ((q0,q1))
        else:
            for q0, q1 in zip(qubits[0::2], qubits[1::2]):
                qml.CZ((q0,q1))

    return qml.expval(qml.PauliZ(qubits[-1]))
    return [qml.expval(qml.PauliZ(q)) for q in qubits]

In [755]:
dev = qml.device('default.qubit.jax', wires=n_qubits)
qubits =list(range(n_qubits))

@jax.jit
@qml.qnode(dev, interface='jax')
def SimpleDRC(inputs, w, s, b):
    for l in range(n_layers):
        x = jnp.transpose(jnp.multiply(s[l],inputs) + b[l]) 
        for q in qubits:
            qml.Rot(*w[l,q], wires = q)
        for q0, q1 in zip(qubits, qubits[1:]):
            qml.CZ((q0, q1))
        if len(qubits) != 2:
            qml.CZ((qubits[0], qubits[-1]))
        for q in qubits:
            qml.RX(x[q], wires=q)
    for q in qubits:
            qml.Rot(*w[n_layers,q], wires = q)
   
    return [qml.expval(qml.PauliZ(q)) for q in qubits]

In [756]:
def get_node(template):
    if template == 'NQubitPQC':
        return NQubitPQC
    elif template == 'SimpleDRC':
        return SimpleDRC
    elif template == 'NQubitPQCSparse':
        return NQubitPQCSparse

In [757]:
def qconv(x, *qweights):
    
    x = jnp.expand_dims(x,axis=0)
    dn = jax.lax.conv_dimension_numbers(x.shape, 
                                        (1,1,kernel_size[0],kernel_size[1]), 
                                        ('NHWC', 'IOHW', 'NHWC'))
    x = jax.lax.conv_general_dilated_patches(lhs = x,
                                               filter_shape= kernel_size,
                                               window_strides=strides,
                                               padding=padding,
                                               dimension_numbers=dn 
                                              )
    iters = x.shape[1:3]
    x = jnp.reshape(x, (-1, np.prod(kernel_size)))
    
    x = get_node(template)(x, *qweights)
    x = jnp.reshape(x, iters + (1,))
    return x

In [758]:
dev = qml.device("default.qubit", wires=n_qubits)
qnode = qml.QNode(get_node(template), dev)

inputs = np.random.uniform(size = (10,np.prod(kernel_size)))
weights = params[0]
drawer = qml.draw(qnode, expansion_strategy="device")
print(drawer(inputs,*weights))

0: ──H──Rot─┤  <Z>


## Auto-Batching Predictions

In [759]:
from jax.scipy.special import logsumexp

def relu(x):
    return jnp.maximum(0, x)

def forward(params, image):
  # per-example predictions
    activations = qconv(image, *params[0])
    activations += image
    res2 = activations.copy()
    activations = relu(activations)
    activations = jnp.reshape(activations, (-1))
    for w, b in params[1:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)

In [760]:
# This works on single examples
random_flattened_image = random.normal(random.PRNGKey(1), input_dims)
random_flattened_image = jnp.floor(random_flattened_image*10)
preds = forward(params,  random_flattened_image)
print(preds)

[  0.        -51.8189508]


In [761]:
# Doesn't work with a batch
random_flattened_images = random.normal(random.PRNGKey(1), (2,)+ input_dims)
random_flattened_images = jnp.floor(random_flattened_images*10)
# try:
#     preds = predict(params, random_flattened_images)
# except TypeError:
#     print('Invalid shapes!')

In [762]:
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_forward = vmap(forward, in_axes=(None,0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_forward(params, random_flattened_images)
print(batched_preds)

[[-1.38863741e-01 -2.04289064e+00]
 [-1.56476390e+01 -1.60072510e-07]]


## Utility and loss functions

In [763]:
auc = tf.keras.metrics.AUC()

def acc_and_auc(y_true, y_pred):
    target_class = jnp.argmax(y_true, axis=1)
    predicted_class = jnp.argmax(y_pred, axis=1)
    auc.update_state(target_class, predicted_class)
    return jnp.mean(predicted_class == target_class), auc.result().numpy()
 

def loss_fn(params, images, targets):
    preds = batched_forward(params, images)
    loss_value = -jnp.mean(preds * targets)
    return loss_value, preds

@jit
def update(opt_state, params, x, y):
    (loss_value, preds), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state


def step(params,x,y):
    loss_value, preds = loss_fn(params, x, y)
    acc, auc_value = acc_and_auc(y, preds)
    return loss_value, acc, auc_value

def evaluate(params, ds):
    res = jnp.array([step(params, x,y)  for x,y in tfds.as_numpy(ds)])
    avg_loss = jnp.mean(res[:,0])
    avg_acc = jnp.mean(res[:,1])
    avg_auc = jnp.mean(res[:,2])
    return avg_loss, avg_acc, avg_auc

## Training loop

In [764]:
schedule_fn = optax.linear_schedule(transition_steps=150,
                                    init_value=0.2,
                                    end_value=1e-7,
                                    )
# Defining an optimizer in Jax
# optimizer = optax.adam(learning_rate=schedule_fn)
optimizer = optax.adam(learning_rate=args.learning_rate)
opt_state = optimizer.init(params)

In [765]:
import time

for epoch in range(args.epochs):
    start_time = time.time()
    with tqdm(tfds.as_numpy(data.train_ds), unit="batch") as tepoch:
        for x, y in tepoch:
            tepoch.set_description(f"Epoch {epoch}")
            params, opt_state = update(opt_state, params, x, y)
            
        
    epoch_time = time.time() - start_time

    train_loss, train_acc, train_auc, = evaluate(params, data.train_ds)
    val_loss, val_acc, val_auc = evaluate(params, data.val_ds)
    print('loss: {} - acc: {} - auc: {} - val_loss: {} - val_acc: {} - val_auc: {} - time: {}'.format(train_loss,
                                                                              train_acc,
                                                                             train_auc,
                                                                              val_loss, 
                                                                              val_acc,
                                                                                val_auc,
                                                                              epoch_time))
    
    if args.wandb:
        wandb.log({"accuracy": train_acc, 
                   'auc': train_auc,
                   "val_accuracy": val_acc, 
                   'val_auc': val_auc,
                   'loss':train_loss, 
                   'val_loss':val_loss})


Epoch 0: 100%|████████████████████████████████████████████████████████████████████| 85/85 [00:11<00:00,  7.44batch/s]
Epoch 1:  18%|███████████▊                                                       | 15/85 [00:00<00:00, 146.67batch/s]

loss: 0.04908277076182867 - acc: 0.9702965154367335 - auc: 0.9699392571168787 - val_loss: 0.07561611213827696 - val_acc: 0.9556014180183411 - val_auc: 0.9693044543266297 - time: 11.42253065109253


Epoch 1: 100%|███████████████████████████████████████████████████████████████████| 85/85 [00:00<00:00, 134.33batch/s]
Epoch 2:  18%|███████████▊                                                       | 15/85 [00:00<00:00, 145.30batch/s]

loss: 0.022874314054954666 - acc: 0.9849464514676263 - auc: 0.9735802966005662 - val_loss: 0.05157559178232182 - val_acc: 0.9704451680183411 - val_auc: 0.9763098895549774 - time: 0.6362545490264893


Epoch 2: 100%|███████████████████████████████████████████████████████████████████| 85/85 [00:00<00:00, 133.21batch/s]
Epoch 3:  18%|███████████▊                                                       | 15/85 [00:00<00:00, 148.60batch/s]

loss: 0.013435704132156053 - acc: 0.9914162404396955 - auc: 0.9789238116320441 - val_loss: 0.04050487949415622 - val_acc: 0.9817069590091706 - val_auc: 0.980971109867096 - time: 0.641472578048706


Epoch 3: 100%|███████████████████████████████████████████████████████████████████| 85/85 [00:00<00:00, 132.51batch/s]
Epoch 4:  16%|███████████                                                        | 14/85 [00:00<00:00, 138.97batch/s]

loss: 0.008763287176412666 - acc: 0.9944852941176471 - auc: 0.9826564129661111 - val_loss: 0.033690559114332726 - val_acc: 0.9851562500000001 - val_auc: 0.9840908229351044 - time: 0.6459476947784424


Epoch 4: 100%|███████████████████████████████████████████████████████████████████| 85/85 [00:00<00:00, 132.57batch/s]


loss: 0.006460194101788134 - acc: 0.9956801470588235 - auc: 0.9852193103117102 - val_loss: 0.033750921492150365 - val_acc: 0.9856132090091706 - val_auc: 0.9862243831157684 - time: 0.6447858810424805


In [766]:
evaluate(params, data.test_ds)

(DeviceArray(0.01136829, dtype=float64),
 DeviceArray(0.99169922, dtype=float64),
 DeviceArray(0.98626414, dtype=float64))

In [768]:
if args.wandb:
    x, y = tf_ds_to_numpy(data.test_ds)
    y = y.argmax(axis=1)
    out = batched_forward(params,x)
    preds = out.argmax(axis=1)
    probs = out
    classes = data.mapping

    roc_curve = wandb.sklearn.plot_roc(y, probs, classes)
    confusion_matrix = wandb.sklearn.plot_confusion_matrix(y, preds, classes)

    wandb.log({"roc_curve": roc_curve})
    wandb.log({"confusion_matrix": confusion_matrix})

In [769]:
if args.wandb:
    wandb.finish()

VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.987772…

0,1
accuracy,▁▅▇██
auc,▁▃▅▇█
loss,█▄▂▁▁
val_accuracy,▁▄▇██
val_auc,▁▄▆▇█
val_loss,█▄▂▁▁

0,1
accuracy,0.99568
auc,0.98522
loss,0.00646
val_accuracy,0.98561
val_auc,0.98622
val_loss,0.03375


In [None]:
for i in range(200):
    print(i, schedule_fn(i))