In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

from importlib.util import find_spec
if find_spec("qml_hep_lhc") is None:
    import sys
    sys.path.append('..')

In [2]:
from qml_hep_lhc.data import ElectronPhoton, MNIST, QuarkGluon
from qml_hep_lhc.data.utils import tf_ds_to_numpy
import argparse
import wandb

import pennylane as qml
import jax.numpy as jnp
import jax
import optax
from jax.nn.initializers import he_uniform
from jax import grad, jit, vmap
from jax import random
import tensorflow_datasets as tfds
from tqdm import tqdm
import numpy as np
import tensorflow as tf
import functools

# Added to silence some warnings.
# from jax.config import config
# config.update("jax_enable_x64", True)

import matplotlib.pyplot as plt
import time

In [3]:
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [4]:
n_devices = jax.local_device_count()
n_devices

8

In [94]:
args = argparse.Namespace()

# Data
args.center_crop = 0.2
# args.resize = [8,8]
args.standardize = 1
# args.power_transform = 1
# args.binary_data = [3,6]
# args.percent_samples = 0.01
# args.processed = 1
args.dataset_type = '1'
args.labels_to_categorical = 1
args.batch_size = 128
args.validation_split = 0.1

# Base Model
args.wandb = False
args.epochs = 10
args.learning_rate = 0.008

# Quantum CNN Parameters
args.n_layers = 1
args.n_qubits = 1
args.template = 'NQubitPQCSparse'
args.initializer = 'he_uniform'

args.kernel_size = (3,3)
args.strides = (1,1)
args.padding = "SAME"

args.clayer_sizes = [8, 2]

In [6]:
if args.wandb:
     wandb.init(project='qml-hep-lhc', config = vars(args))

In [7]:
data = ElectronPhoton(args)
data.prepare_data()
data.setup()
print(data)

Center cropping...
Center cropping...
Standardizing data...
Converting labels to categorical...
Converting labels to categorical...

Dataset :Electron Photon 1
╒════════╤═════════════════╤════════════════╤═════════════════╤═══════════╕
│ Data   │ Train size      │ Val size       │ Test size       │ Dims      │
╞════════╪═════════════════╪════════════════╪═════════════════╪═══════════╡
│ X      │ (8100, 8, 8, 1) │ (900, 8, 8, 1) │ (1000, 8, 8, 1) │ (8, 8, 1) │
├────────┼─────────────────┼────────────────┼─────────────────┼───────────┤
│ y      │ (8100, 2)       │ (900, 2)       │ (1000, 2)       │ (2,)      │
╘════════╧═════════════════╧════════════════╧═════════════════╧═══════════╛

╒══════════════╤═══════╤═══════╤════════╤═══════╤══════════════════════════╕
│ Type         │   Min │   Max │   Mean │   Std │ Samples for each class   │
╞══════════════╪═══════╪═══════╪════════╪═══════╪══════════════════════════╡
│ Train Images │ -2.88 │ 43.39 │  -0    │  0.98 │ [4050, 4050]             │

In [8]:
input_dims = data.config()['input_dims']

In [9]:
x_train, y_train = tf_ds_to_numpy(data.train_ds)
x_val, y_val = tf_ds_to_numpy(data.val_ds)
x_test, y_test = tf_ds_to_numpy(data.val_ds)
del data

In [10]:
# make samples multiple of n_devices
size = x_train.shape[0]
size = (size//n_devices)*n_devices
x_train = x_train[:size]
y_train = y_train[:size]

size = x_val.shape[0]
size = (size//n_devices)*n_devices
x_val = x_val[:size]
y_val = y_val[:size]

size = x_test.shape[0]
size = (size//n_devices)*n_devices
x_test = x_test[:size]
y_test = y_test[:size]

In [11]:
print(x_train.shape,y_train.shape)
print(x_val.shape,y_val.shape)
print(x_test.shape,y_test.shape)

(8096, 8, 8, 1) (8096, 2)
(896, 8, 8, 1) (896, 2)
(896, 8, 8, 1) (896, 2)


In [12]:
def split(arr):
    """Splits the first axis of `arr` evenly across the number of devices."""
    return arr.reshape(n_devices, arr.shape[0] // n_devices, *arr.shape[1:])

In [13]:
# Reshape xs and ys for the pmapped `update()`.
x_train_split = split(x_train)
y_train_split = split(y_train)

x_val_split = split(x_val)
y_val_split = split(y_val)

x_test_split = split(x_test)
y_test_split = split(y_test)

In [14]:
print(x_train_split.shape,y_train_split.shape)
print(x_val_split.shape,y_val_split.shape)
print(x_test_split.shape,y_test_split.shape)

(8, 1012, 8, 8, 1) (8, 1012, 2)
(8, 112, 8, 8, 1) (8, 112, 2)
(8, 112, 8, 8, 1) (8, 112, 2)


## Hyperparameters

In [15]:
def get_out_shape(in_shape, k, s, padding):
    in_shape = (1,) + in_shape
    a = np.random.uniform(size = (in_shape))
    dn = jax.lax.conv_dimension_numbers(a.shape, (1,1,k[0],k[1]), ('NHWC', 'IOHW', 'NHWC'))
    out = jax.lax.conv_general_dilated_patches(lhs = a,
                                           filter_shape= k,
                                           window_strides=s,
                                           padding=padding,
                                           dimension_numbers=dn 
                                    )
    return out.shape

In [16]:
initializer = he_uniform()

# Get qlayer sizes
def get_qlayer_sizes(template, n_l, n_q, k_size):
    if template == 'NQubitPQCSparse':
        return {
            'w': (n_l, n_q,3,np.prod(k_size)),
            'b': (n_l,n_q,3,1)
        }
    elif template == 'SimpleDRC':
        return {
            'w': (n_l+1,n_q,3),
            's': (n_l,n_q),
            'b': (n_l,n_q)
        }
    elif template == 'NQubitPQC':
        assert np.prod(k_size)%3 == 0
        return {
            'w': (n_l,n_q,np.prod(k_size)),
            'b': (n_l,n_q,np.prod(k_size))
        }

def random_qlayer_params(size, key, scale=1e-1):
    return initializer(key, size)
    return scale * random.normal(key, size)

def init_qnetwork_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [[random_qlayer_params(size, key) for size, key in zip(sizes.values(), keys)]]
 

# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_clayer_params(m, n, key, scale=1e-1):
    w_key, b_key = random.split(key)
    return initializer(w_key, (n,m)), random.normal(b_key, (n,))
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [random_clayer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

kernel_size = args.kernel_size
strides = args.strides
padding = args.padding
clayer_sizes = args.clayer_sizes

template = args.template
n_layers = args.n_layers
n_qubits = args.n_qubits


conv_out_shape = get_out_shape(input_dims, kernel_size, strides, padding)
num_pixels = np.prod(conv_out_shape[:-1])*1
qlayer_sizes = get_qlayer_sizes(template, n_layers, n_qubits, kernel_size)
clayer_sizes = [num_pixels] + clayer_sizes

params = []
params += init_qnetwork_params(qlayer_sizes, random.PRNGKey(0))
# params += init_qnetwork_params(qlayer_sizes, random.PRNGKey(1))
params += init_network_params(clayer_sizes, random.PRNGKey(2))

  rhs = jnp.eye(spatial_size, dtype=lhs.dtype).reshape(filter_shape * 2)


In [17]:
# Replicate params across devices
replicated_params = jax.tree_map(lambda x: jnp.array([x] * n_devices), params)

In [18]:
for i in params:
    for j in i:
        print(j.shape, end = ' ')
    print()

(1, 1, 3, 9) (1, 1, 3, 1) 
(8, 64) (8,) 
(2, 8) (2,) 


In [19]:
for i in replicated_params:
    for j in i:
        print(j.shape, end = ' ')
    print()

(8, 1, 1, 3, 9) (8, 1, 1, 3, 1) 
(8, 8, 64) (8, 8) 
(8, 2, 8) (8, 2) 


## QLayers

In [20]:
dev = qml.device('default.qubit.jax', wires=n_qubits)
qubits =list(range(n_qubits))

@jax.jit
@qml.qnode(dev, interface='jax')
def NQubitPQCSparse(inputs, w, b):
    z = jnp.dot(w, jnp.transpose(inputs))+ b

    for q in qubits:
        qml.Hadamard(wires=q)
    
    for l in range(n_layers):
        for q in qubits:
            qml.Rot(z[l,q,0], z[l,q,1], z[l,q,2], wires= q)
        if (l & 1):
            for q0, q1 in zip(qubits[1::2], qubits[2::2] + [qubits[0]]):
                qml.CZ((q0,q1))
        else:
            for q0, q1 in zip(qubits[0::2], qubits[1::2]):
                qml.CZ((q0,q1))
   
    return qml.expval(qml.PauliZ(qubits[-1]))
    return [qml.expval(qml.PauliZ(q)) for q in qubits]

In [21]:
dev = qml.device('default.qubit.jax', wires=n_qubits)
qubits =list(range(n_qubits))

@jax.jit
@qml.qnode(dev, interface='jax')
def NQubitPQC(inputs, w, b):
    steps = inputs.shape[-1]//3
    for q in qubits:
        qml.Hadamard(wires=q)
    
    for l in range(n_layers):
        for q in qubits:
            for i in range(steps):
                z = jnp.transpose(jnp.multiply(inputs[:,3*i:3*i+3],w[l,q,3*i:3*i+3]) + b[l,q,3*i:3*i+3])
                qml.RX(z[0], wires=q)
                qml.RY(z[1], wires=q)
                qml.RZ(z[2], wires=q)
                
        if (l & 1):
            for q0, q1 in zip(qubits[1::2], qubits[2::2] + [qubits[0]]):
                qml.CZ((q0,q1))
        else:
            for q0, q1 in zip(qubits[0::2], qubits[1::2]):
                qml.CZ((q0,q1))

    return qml.expval(qml.PauliZ(qubits[-1]))
    return [qml.expval(qml.PauliZ(q)) for q in qubits]

In [22]:
dev = qml.device('default.qubit.jax', wires=n_qubits)
qubits =list(range(n_qubits))

@jax.jit
@qml.qnode(dev, interface='jax')
def SimpleDRC(inputs, w, s, b):
    for l in range(n_layers):
        x = jnp.transpose(jnp.multiply(s[l],inputs) + b[l]) 
        for q in qubits:
            qml.Rot(*w[l,q], wires = q)
        for q0, q1 in zip(qubits, qubits[1:]):
            qml.CZ((q0, q1))
        if len(qubits) != 2:
            qml.CZ((qubits[0], qubits[-1]))
        for q in qubits:
            qml.RX(x[q], wires=q)
    for q in qubits:
            qml.Rot(*w[n_layers,q], wires = q)
   
    return [qml.expval(qml.PauliZ(q)) for q in qubits]

In [23]:
def get_node(template):
    if template == 'NQubitPQC':
        return NQubitPQC
    elif template == 'SimpleDRC':
        return SimpleDRC
    elif template == 'NQubitPQCSparse':
        return NQubitPQCSparse

In [24]:
def qconv(x, *qweights):
    
    x = jnp.expand_dims(x,axis=0)
    dn = jax.lax.conv_dimension_numbers(x.shape, 
                                        (1,1,kernel_size[0],kernel_size[1]), 
                                        ('NHWC', 'IOHW', 'NHWC'))
    x = jax.lax.conv_general_dilated_patches(lhs = x,
                                               filter_shape= kernel_size,
                                               window_strides=strides,
                                               padding=padding,
                                               dimension_numbers=dn 
                                              )
    iters = x.shape[1:3]
    x = jnp.reshape(x, (-1, np.prod(kernel_size)))
    
    x = get_node(template)(x, *qweights)
    x = jnp.reshape(x, iters + (1,))
    return x

In [25]:
dev = qml.device("default.qubit", wires=n_qubits)
qnode = qml.QNode(get_node(template), dev)

inputs = np.random.uniform(size = (10,np.prod(kernel_size)))
weights = params[0]
drawer = qml.draw(qnode, expansion_strategy="device")
print(drawer(inputs,*weights))

0: ──H──Rot─┤  <Z>


## Auto-Batching Predictions

In [26]:
from jax.scipy.special import logsumexp

def relu(x):
    return jnp.maximum(0, x)

def forward(params, image):
  # per-example predictions
    activations = qconv(image, *params[0])
    activations += image
    res2 = activations.copy()
    activations = relu(activations)
    activations = jnp.reshape(activations, (-1))
    for w, b in params[1:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)

In [27]:
# This works on single examples
random_flattened_image = random.normal(random.PRNGKey(1), input_dims)
random_flattened_image = jnp.floor(random_flattened_image*10)
preds = forward(params,  random_flattened_image)
print(preds)

[   0.     -154.1885]


In [28]:
# Doesn't work with a batch
random_flattened_images = random.normal(random.PRNGKey(1), (2,)+ input_dims)
random_flattened_images = jnp.floor(random_flattened_images*10)
# try:
#     preds = predict(params, random_flattened_images)
# except TypeError:
#     print('Invalid shapes!')

In [29]:
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_forward = vmap(forward, in_axes=(None,0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_forward(params, random_flattened_images)
print(batched_preds)

[[   0.       -163.02367 ]
 [   0.        -45.519386]]


## Utility and loss functions

In [111]:
auc = tf.keras.metrics.AUC()

def acc(y_true, y_pred):
    target_class = jnp.argmax(y_true, axis=1)
    predicted_class = jnp.argmax(y_pred, axis=1)
    return jnp.mean(predicted_class == target_class)
 

def loss_fn(params, images, targets):
    preds = batched_forward(params, images)
    loss_value = -jnp.mean(preds * targets)
    return loss_value, preds

@jit
def update(params, x, y):
    (loss_value, preds), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, x, y)
    return grads
    

@functools.partial(jax.pmap, axis_name='num_devices')
def train_epoch(params, x, y):
    train_ds_size = x.shape[0]
    steps_per_epoch = train_ds_size // args.batch_size

    perms = jax.random.permutation(jax.random.PRNGKey(0), train_ds_size)
    perms = perms[:steps_per_epoch * args.batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, args.batch_size))

    for perm in perms:
        batch_images = x[perm, ...]
        batch_labels = y[perm, ...]
        grads = update(params, batch_images, batch_labels)
        
        # Combine the gradient across all devices (by taking their mean).
        grads = jax.lax.pmean(grads, axis_name='num_devices')

        # Each device performs its own update, but since we start with the same params
        # and synchronise gradients, the params stay in sync.
        params = jax.tree_map(lambda param, g: param - g * args.learning_rate, params, grads)

    return params


def step(params,x,y):
    loss_value, preds = loss_fn(params, x, y)
    a = acc(y, preds)
    return loss_value, a

@functools.partial(jax.pmap, axis_name='nd')
def evaluate(params, x,y):
    loss_value, a = step(params,x,y)
    avg_loss = jax.lax.pmean(loss_value, axis_name = 'nd')
    avg_acc = jax.lax.pmean(a, axis_name = 'nd')
    return avg_loss, avg_acc

## Training loop

In [112]:
schedule_fn = optax.linear_schedule(transition_steps=150,
                                    init_value=0.2,
                                    end_value=1e-7,
                                    )
# Defining an optimizer in Jax
# optimizer = optax.adam(learning_rate=schedule_fn)
optimizer = optax.adam(learning_rate=args.learning_rate)
opt_state = optimizer.init(replicated_params)

In [113]:
import time

for epoch in range(10):
    start_time = time.time()
    replicated_params = train_epoch(replicated_params, x_train_split, y_train_split)
    
    epoch_time = time.time() - start_time

    train_loss, train_acc = evaluate(replicated_params, x_train_split, y_train_split)
    val_loss, val_acc = evaluate(replicated_params, x_val_split, y_val_split)
    print('epoch: {} - loss: {} - acc: {} - val_loss: {} - val_acc: {} - time: {}\n'.format(epoch, train_loss[0],
                                                                              train_acc[0],
                                                                              val_loss[0], 
                                                                              val_acc[0],
                                                                              epoch_time))
    
    if args.wandb:
        wandb.log({"accuracy": train_acc, 
                   "val_accuracy": val_acc, 
                   'loss':train_loss, 
                   'val_loss':val_loss})


epoch: 0 - loss: 1.026045799255371 - acc: 0.5281621217727661 - val_loss: 1.0353219509124756 - val_acc: 0.5223214626312256 - time: 26.27303194999695

epoch: 1 - loss: 0.9971594214439392 - acc: 0.5274209976196289 - val_loss: 1.005469799041748 - val_acc: 0.5189732313156128 - time: 0.002466440200805664

epoch: 2 - loss: 0.962954044342041 - acc: 0.5293972492218018 - val_loss: 0.9682097434997559 - val_acc: 0.5178571939468384 - time: 0.0014810562133789062

epoch: 3 - loss: 0.9388502240180969 - acc: 0.5277915000915527 - val_loss: 0.941805362701416 - val_acc: 0.5178571939468384 - time: 0.0014755725860595703

epoch: 4 - loss: 0.914470374584198 - acc: 0.5281620621681213 - val_loss: 0.9157799482345581 - val_acc: 0.5189732313156128 - time: 0.001399993896484375

epoch: 5 - loss: 0.8957167863845825 - acc: 0.5290266871452332 - val_loss: 0.895794689655304 - val_acc: 0.5167410969734192 - time: 0.0014042854309082031

epoch: 6 - loss: 0.8753881454467773 - acc: 0.5287796258926392 - val_loss: 0.873018622398

In [24]:
evaluate(params, data.test_ds)

(DeviceArray(0.34598082, dtype=float32),
 DeviceArray(0.5691106, dtype=float32),
 DeviceArray(0.5751912, dtype=float32))

In [768]:
if args.wandb:
    x, y = tf_ds_to_numpy(data.test_ds)
    y = y.argmax(axis=1)
    out = batched_forward(params,x)
    preds = out.argmax(axis=1)
    probs = out
    classes = data.mapping

    roc_curve = wandb.sklearn.plot_roc(y, probs, classes)
    confusion_matrix = wandb.sklearn.plot_confusion_matrix(y, preds, classes)

    wandb.log({"roc_curve": roc_curve})
    wandb.log({"confusion_matrix": confusion_matrix})

In [769]:
if args.wandb:
    wandb.finish()

VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.987772…

0,1
accuracy,▁▅▇██
auc,▁▃▅▇█
loss,█▄▂▁▁
val_accuracy,▁▄▇██
val_auc,▁▄▆▇█
val_loss,█▄▂▁▁

0,1
accuracy,0.99568
auc,0.98522
loss,0.00646
val_accuracy,0.98561
val_auc,0.98622
val_loss,0.03375


In [None]:
for i in range(200):
    print(i, schedule_fn(i))