In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

from importlib.util import find_spec
if find_spec("qml_hep_lhc") is None:
    import sys
    sys.path.append('..')

In [2]:
from qml_hep_lhc.data import ElectronPhoton, MNIST, QuarkGluon
from qml_hep_lhc.data.utils import tf_ds_to_numpy
import argparse
import wandb

import pennylane as qml
import jax.numpy as jnp
import jax
import optax
from jax.nn.initializers import he_uniform
from jax import grad, jit, vmap
from jax import random
import tensorflow_datasets as tfds
from tqdm import tqdm
import numpy as np
import tensorflow as tf
import functools

# Added to silence some warnings.
# from jax.config import config
# config.update("jax_enable_x64", True)

import matplotlib.pyplot as plt
import time

In [3]:
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [4]:
n_devices = jax.local_device_count()
n_devices

8

In [90]:
args = argparse.Namespace()

# Data
args.center_crop = 0.2
# args.resize = [8,8]
args.standardize = 1
# args.power_transform = 1
# args.binary_data = [3,6]
# args.percent_samples = 0.01
# args.processed = 1
args.dataset_type = '3'
args.labels_to_categorical = 1
args.batch_size = 128
args.validation_split = 0.1

# Base Model
args.wandb = False
args.epochs = 10
args.learning_rate = 0.001

# Quantum CNN Parameters
args.n_layers = 1
args.n_qubits = 1
args.template = 'NQubitPQCSparse'
args.initializer = 'he_uniform'

args.kernel_size = (3,3)
args.strides = (1,1)
args.padding = "SAME"

args.clayer_sizes = [8, 2]

In [91]:
if args.wandb:
     wandb.init(project='qml-hep-lhc', config = vars(args))

In [92]:
data = ElectronPhoton(args)
data.prepare_data()
data.setup()
print(data)

tcmalloc: large alloc 1638400000 bytes == 0x10779a000 @  0x7f58ea343680 0x7f58ea364824 0x7f58dfa7d064 0x7f58dfa7d7ff 0x7f58dfadbfc5 0x7f58dfa80d08 0x5f73e3 0x57164c 0x569dba 0x5f6eb3 0x56cc1f 0x5f6cd6 0x59e95f 0x5139cc 0x56bf28 0x5f6cd6 0x56bbfa 0x569dba 0x6902a7 0x6023c4 0x5c6730 0x56bacd 0x501488 0x56d4d6 0x501488 0x56d4d6 0x501488 0x505166 0x56bbfa 0x5f6cd6 0x56bacd
tcmalloc: large alloc 1638400000 bytes == 0x18ca18000 @  0x7f58ea343680 0x7f58ea364824 0x7f58dfa7d064 0x7f58dfa7d7ff 0x7f58dfadbfc5 0x7f58dfadc126 0x7f58dfb6e7ea 0x7f58dfb6f24b 0x5139cc 0x56bf28 0x569dba 0x5f6eb3 0x56cc1f 0x569dba 0x5f6eb3 0x56bacd 0x569dba 0x5f6eb3 0x56bacd 0x569dba 0x5f6eb3 0x5f6082 0x56d2d5 0x569dba 0x5f6eb3 0x56bacd 0x5f6cd6 0x56bbfa 0x569dba 0x6902a7 0x6023c4


Center cropping...
Center cropping...
Standardizing data...
Converting labels to categorical...
Converting labels to categorical...

Dataset :Electron Photon 3
╒════════╤═══════════════════╤══════════════════╤══════════════════╤═══════════╕
│ Data   │ Train size        │ Val size         │ Test size        │ Dims      │
╞════════╪═══════════════════╪══════════════════╪══════════════════╪═══════════╡
│ X      │ (360000, 8, 8, 1) │ (40000, 8, 8, 1) │ (98000, 8, 8, 1) │ (8, 8, 1) │
├────────┼───────────────────┼──────────────────┼──────────────────┼───────────┤
│ y      │ (360000, 2)       │ (40000, 2)       │ (98000, 2)       │ (2,)      │
╘════════╧═══════════════════╧══════════════════╧══════════════════╧═══════════╛

╒══════════════╤═══════╤════════╤════════╤═══════╤══════════════════════════╕
│ Type         │   Min │    Max │   Mean │   Std │ Samples for each class   │
╞══════════════╪═══════╪════════╪════════╪═══════╪══════════════════════════╡
│ Train Images │ -2.88 │ 107.37 │     

In [93]:
input_dims = data.config()['input_dims']

In [94]:
x_train, y_train = tf_ds_to_numpy(data.train_ds)
del data.train_ds
x_val, y_val = tf_ds_to_numpy(data.val_ds)
del data.val_ds
x_test, y_test = tf_ds_to_numpy(data.test_ds)
del data.test_ds
del data

In [95]:
# make samples multiple of n_devices
size = x_train.shape[0]
size = (size//n_devices)*n_devices
x_train = x_train[:size]
y_train = y_train[:size]

size = x_val.shape[0]
size = (size//n_devices)*n_devices
x_val = x_val[:size]
y_val = y_val[:size]

size = x_test.shape[0]
size = (size//n_devices)*n_devices
x_test = x_test[:size]
y_test = y_test[:size]

In [96]:
print(x_train.shape,y_train.shape)
print(x_val.shape,y_val.shape)
print(x_test.shape,y_test.shape)

(360000, 8, 8, 1) (360000, 2)
(40000, 8, 8, 1) (40000, 2)
(98000, 8, 8, 1) (98000, 2)


In [97]:
def split(arr):
    """Splits the first axis of `arr` evenly across the number of devices."""
    return arr.reshape(n_devices, arr.shape[0] // n_devices, *arr.shape[1:])

In [98]:
# Reshape xs and ys for the pmapped `update()`.
x_train = split(x_train)
y_train = split(y_train)

x_val = split(x_val)
y_val = split(y_val)

x_test = split(x_test)
y_test = split(y_test)

In [99]:
print(x_train.shape,y_train.shape)
print(x_val.shape,y_val.shape)
print(x_test.shape,y_test.shape)

(8, 45000, 8, 8, 1) (8, 45000, 2)
(8, 5000, 8, 8, 1) (8, 5000, 2)
(8, 12250, 8, 8, 1) (8, 12250, 2)


## Hyperparameters

In [118]:
def get_out_shape(in_shape, k, s, padding):
    in_shape = (1,) + in_shape
    a = np.random.uniform(size = (in_shape))
    dn = jax.lax.conv_dimension_numbers(a.shape, (1,1,k[0],k[1]), ('NHWC', 'IOHW', 'NHWC'))
    out = jax.lax.conv_general_dilated_patches(lhs = a,
                                           filter_shape= k,
                                           window_strides=s,
                                           padding=padding,
                                           dimension_numbers=dn 
                                    )
    return out.shape

In [119]:
initializer = he_uniform()

# Get qlayer sizes
def get_qlayer_sizes(template, n_l, n_q, k_size):
    if template == 'NQubitPQCSparse':
        return {
            'w': (n_l, n_q,3,np.prod(k_size)),
            'b': (n_l,n_q,3,1)
        }
    elif template == 'SimpleDRC':
        return {
            'w': (n_l+1,n_q,3),
            's': (n_l,n_q),
            'b': (n_l,n_q)
        }
    elif template == 'NQubitPQC':
        assert np.prod(k_size)%3 == 0
        return {
            'w': (n_l,n_q,np.prod(k_size)),
            'b': (n_l,n_q,np.prod(k_size))
        }

def random_qlayer_params(size, key, scale=1e-1):
    return initializer(key, size)
    return scale * random.normal(key, size)

def init_qnetwork_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [[random_qlayer_params(size, key) for size, key in zip(sizes.values(), keys)]]
 

# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_clayer_params(m, n, key, scale=1e-1):
    w_key, b_key = random.split(key)
    return initializer(w_key, (n,m)), random.normal(b_key, (n,))
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [random_clayer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

kernel_size = args.kernel_size
strides = args.strides
padding = args.padding
clayer_sizes = args.clayer_sizes

template = args.template
n_layers = args.n_layers
n_qubits = args.n_qubits


conv_out_shape = get_out_shape(input_dims, kernel_size, strides, padding)
num_pixels = np.prod(conv_out_shape[:-1])*1
qlayer_sizes = get_qlayer_sizes(template, n_layers, n_qubits, kernel_size)
clayer_sizes = [num_pixels] + clayer_sizes

params = []
params += init_qnetwork_params(qlayer_sizes, random.PRNGKey(0))
# params += init_qnetwork_params(qlayer_sizes, random.PRNGKey(1))
params += init_network_params(clayer_sizes, random.PRNGKey(2))

In [120]:
# Replicate params across devices
replicated_params = jax.tree_map(lambda x: jnp.array([x] * n_devices), params)

In [121]:
for i in params:
    for j in i:
        print(j.shape, end = ' ')
    print()

(1, 1, 3, 9) (1, 1, 3, 1) 
(8, 64) (8,) 
(2, 8) (2,) 


In [122]:
for i in replicated_params:
    for j in i:
        print(j.shape, end = ' ')
    print()

(8, 1, 1, 3, 9) (8, 1, 1, 3, 1) 
(8, 8, 64) (8, 8) 
(8, 2, 8) (8, 2) 


## QLayers

In [123]:
dev = qml.device('default.qubit.jax', wires=n_qubits)
qubits =list(range(n_qubits))

@jax.jit
@qml.qnode(dev, interface='jax')
def NQubitPQCSparse(inputs, w, b):
    z = jnp.dot(w, jnp.transpose(inputs))+ b

    for q in qubits:
        qml.Hadamard(wires=q)
    
    for l in range(n_layers):
        for q in qubits:
            qml.Rot(z[l,q,0], z[l,q,1], z[l,q,2], wires= q)
        if (l & 1):
            for q0, q1 in zip(qubits[1::2], qubits[2::2] + [qubits[0]]):
                qml.CZ((q0,q1))
        else:
            for q0, q1 in zip(qubits[0::2], qubits[1::2]):
                qml.CZ((q0,q1))
   
    return qml.expval(qml.PauliZ(qubits[-1]))
    return [qml.expval(qml.PauliZ(q)) for q in qubits]

In [124]:
dev = qml.device('default.qubit.jax', wires=n_qubits)
qubits =list(range(n_qubits))

@jax.jit
@qml.qnode(dev, interface='jax')
def NQubitPQC(inputs, w, b):
    steps = inputs.shape[-1]//3
    for q in qubits:
        qml.Hadamard(wires=q)
    
    for l in range(n_layers):
        for q in qubits:
            for i in range(steps):
                z = jnp.transpose(jnp.multiply(inputs[:,3*i:3*i+3],w[l,q,3*i:3*i+3]) + b[l,q,3*i:3*i+3])
                qml.RX(z[0], wires=q)
                qml.RY(z[1], wires=q)
                qml.RZ(z[2], wires=q)
                
        if (l & 1):
            for q0, q1 in zip(qubits[1::2], qubits[2::2] + [qubits[0]]):
                qml.CZ((q0,q1))
        else:
            for q0, q1 in zip(qubits[0::2], qubits[1::2]):
                qml.CZ((q0,q1))

    return qml.expval(qml.PauliZ(qubits[-1]))
    return [qml.expval(qml.PauliZ(q)) for q in qubits]

In [125]:
dev = qml.device('default.qubit.jax', wires=n_qubits)
qubits =list(range(n_qubits))

@jax.jit
@qml.qnode(dev, interface='jax')
def SimpleDRC(inputs, w, s, b):
    for l in range(n_layers):
        x = jnp.transpose(jnp.multiply(s[l],inputs) + b[l]) 
        for q in qubits:
            qml.Rot(*w[l,q], wires = q)
        for q0, q1 in zip(qubits, qubits[1:]):
            qml.CZ((q0, q1))
        if len(qubits) != 2:
            qml.CZ((qubits[0], qubits[-1]))
        for q in qubits:
            qml.RX(x[q], wires=q)
    for q in qubits:
            qml.Rot(*w[n_layers,q], wires = q)
   
    return [qml.expval(qml.PauliZ(q)) for q in qubits]

In [126]:
def get_node(template):
    if template == 'NQubitPQC':
        return NQubitPQC
    elif template == 'SimpleDRC':
        return SimpleDRC
    elif template == 'NQubitPQCSparse':
        return NQubitPQCSparse

In [127]:
def qconv(x, *qweights):
    
    x = jnp.expand_dims(x,axis=0)
    dn = jax.lax.conv_dimension_numbers(x.shape, 
                                        (1,1,kernel_size[0],kernel_size[1]), 
                                        ('NHWC', 'IOHW', 'NHWC'))
    x = jax.lax.conv_general_dilated_patches(lhs = x,
                                               filter_shape= kernel_size,
                                               window_strides=strides,
                                               padding=padding,
                                               dimension_numbers=dn 
                                              )
    iters = x.shape[1:3]
    x = jnp.reshape(x, (-1, np.prod(kernel_size)))
    
    x = get_node(template)(x, *qweights)
    x = jnp.reshape(x, iters + (1,))
    return x

In [128]:
dev = qml.device("default.qubit", wires=n_qubits)
qnode = qml.QNode(get_node(template), dev)

inputs = np.random.uniform(size = (10,np.prod(kernel_size)))
weights = params[0]
drawer = qml.draw(qnode, expansion_strategy="device")
print(drawer(inputs,*weights))

0: ──H──Rot─┤  <Z>


## Auto-Batching Predictions

In [129]:
from jax.scipy.special import logsumexp

def relu(x):
    return jnp.maximum(0, x)

def forward(params, image):
  # per-example predictions
    activations = qconv(image, *params[0])
    activations += image
    res2 = activations.copy()
    activations = relu(activations)
    activations = jnp.reshape(activations, (-1))
    for w, b in params[1:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)

In [130]:
# This works on single examples
random_flattened_image = random.normal(random.PRNGKey(1), input_dims)
random_flattened_image = jnp.floor(random_flattened_image*10)
preds = forward(params,  random_flattened_image)
print(preds)

[   0.      -162.26166]


In [131]:
# Doesn't work with a batch
random_flattened_images = random.normal(random.PRNGKey(1), (2,)+ input_dims)
random_flattened_images = jnp.floor(random_flattened_images*10)
# try:
#     preds = predict(params, random_flattened_images)
# except TypeError:
#     print('Invalid shapes!')

In [132]:
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_forward = vmap(forward, in_axes=(None,0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_forward(params, random_flattened_images)
print(batched_preds)

[[   0.       -167.37828 ]
 [   0.        -33.438576]]


## Utility and loss functions

In [133]:
def acc(y_true, y_pred):
    target_class = jnp.argmax(y_true, axis=1)
    predicted_class = jnp.argmax(y_pred, axis=1)
    return jnp.mean(predicted_class == target_class)
 

def loss_fn(params, images, targets):
    preds = batched_forward(params, images)
    loss_value = -jnp.mean(preds * targets)
    return loss_value, preds

@jit
def update(params, x, y):
    (loss_value, preds), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, x, y)
    return grads
    

@functools.partial(jax.pmap, axis_name='num_devices')
def train_epoch(params, x, y):
    train_ds_size = x.shape[0]
    steps_per_epoch = train_ds_size // args.batch_size

    perms = jax.random.permutation(jax.random.PRNGKey(0), train_ds_size)
    perms = perms[:steps_per_epoch * args.batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, args.batch_size))
    
    with tqdm(perms, unit="batch") as tepoch:
        for perm in tepoch:
            tepoch.set_description(f"Epoch")
            
            grads = update(params, x[perm, ...], y[perm, ...])

            # Combine the gradient across all devices (by taking their mean).
            grads = jax.lax.pmean(grads, axis_name='num_devices')

            # Each device performs its own update, but since we start with the same params
            # and synchronise gradients, the params stay in sync.
            params = jax.tree_map(lambda param, g: param - g * args.learning_rate, params, grads)

    return params


def step(params,x,y):
    loss_value, preds = loss_fn(params, x, y)
    a = acc(y, preds)
    return loss_value, a

@functools.partial(jax.pmap, axis_name='nd')
def evaluate(params, x,y):
    loss_value, a = step(params,x,y)
    avg_loss = jax.lax.pmean(loss_value, axis_name = 'nd')
    avg_acc = jax.lax.pmean(a, axis_name = 'nd')
    return avg_loss, avg_acc

@functools.partial(jax.pmap, axis_name='nd')
def predict(params, x):
    preds = batched_forward(params, x)
    avg_preds = jax.lax.pmean(preds, axis_name = 'nd')
    return avg_preds

## Training loop

In [134]:
schedule_fn = optax.linear_schedule(transition_steps=150,
                                    init_value=0.2,
                                    end_value=1e-7,
                                    )
# Defining an optimizer in Jax
# optimizer = optax.adam(learning_rate=schedule_fn)
optimizer = optax.adam(learning_rate=args.learning_rate)
opt_state = optimizer.init(replicated_params)

In [None]:
# Run for the first time to move data and params from CPU to TPU

start_time = time.time()
replicated_params = train_epoch(replicated_params, x_train, y_train) 

epoch_time = time.time() - start_time
print(epoch_time)

Epoch: 100%|████████████████████████████████████████████████████████████████████| 351/351 [01:03<00:00,  5.53batch/s]


In [89]:
import time

for epoch in range(args.epochs):
    start_time = time.time()
    replicated_params = train_epoch(replicated_params, x_train, y_train)
    
    epoch_time = time.time() - start_time
    print(epoch_time)

#     train_loss, train_acc = evaluate(replicated_params, x_train, y_train)
#     val_loss, val_acc = evaluate(replicated_params, x_val, y_val)
#     print('epoch: {} - loss: {} - acc: {} - val_loss: {} - val_acc: {} - time: {}\n'.format(epoch, train_loss[0],
#                                                                               train_acc[0],
#                                                                               val_loss[0], 
#                                                                               val_acc[0],
#                                                                               epoch_time))
    
    if args.wandb:
        wandb.log({"accuracy": train_acc[0], 
                   "val_accuracy": val_acc[0], 
                   'loss':train_loss[0], 
                   'val_loss':val_loss[0]})


0.002457141876220703
0.0023679733276367188
0.001874685287475586
0.0016205310821533203
0.0016393661499023438
0.0014712810516357422
0.0015168190002441406
0.0015435218811035156
0.0014514923095703125
0.0017185211181640625


In [39]:
test_loss, test_acc = evaluate(replicated_params, x_test, y_test)
test_loss[0], test_acc[0]

(DeviceArray(1.6530111, dtype=float32), DeviceArray(0.51600003, dtype=float32))

In [50]:
from sklearn.metrics import roc_auc_score

preds = predict(replicated_params, x_test)
roc_auc_score(y_test.reshape(-1,2), preds.reshape(-1,2))

0.48265600000000003

In [768]:
if args.wandb:
    x, y = tf_ds_to_numpy(data.test_ds)
    y = y.argmax(axis=1)
    out = batched_forward(params,x)
    preds = out.argmax(axis=1)
    probs = out
    classes = data.mapping

    roc_curve = wandb.sklearn.plot_roc(y, probs, classes)
    confusion_matrix = wandb.sklearn.plot_confusion_matrix(y, preds, classes)

    wandb.log({"roc_curve": roc_curve})
    wandb.log({"confusion_matrix": confusion_matrix})

In [769]:
if args.wandb:
    wandb.finish()

VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.987772…

0,1
accuracy,▁▅▇██
auc,▁▃▅▇█
loss,█▄▂▁▁
val_accuracy,▁▄▇██
val_auc,▁▄▆▇█
val_loss,█▄▂▁▁

0,1
accuracy,0.99568
auc,0.98522
loss,0.00646
val_accuracy,0.98561
val_auc,0.98622
val_loss,0.03375


In [None]:
for i in range(200):
    print(i, schedule_fn(i))