In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

from importlib.util import find_spec
if find_spec("qml_hep_lhc") is None:
    import sys
    sys.path.append('..')

In [2]:
from qml_hep_lhc.data import ElectronPhoton, MNIST, QuarkGluon
from qml_hep_lhc.data.utils import tf_ds_to_numpy
from qml_hep_lhc.data.preprocessor import DataPreprocessor
from sklearn.utils import shuffle
import argparse
import wandb

import pennylane as qml
from pennylane import numpy as np
from pennylane.optimize import AdamOptimizer, GradientDescentOptimizer
import jax.numpy as jnp
import jax
from flax.training import train_state
from flax import linen as nn
import optax
from absl import logging
from jax import grad, jit, vmap
from jax import random
import tensorflow_datasets as tfds
from tqdm import tqdm
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt
import time

2022-08-16 09:09:09.421609: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2022-08-16 09:09:09.421696: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [398]:
args = argparse.Namespace()

# Data
args.center_crop = 0.2
# args.resize = [8,8]
args.standardize = 1
args.power_transform = 1
# args.binary_data = [0,1]
# args.percent_samples = 0.01
# args.processed = 1
args.dataset_type = 'med'
args.labels_to_categorical = 1
args.batch_size = 128

# Base Model
args.wandb = False
# args.epochs = 30
# args.learning_rate = 0.2

# Quantum CNN Parameters
# args.n_layers = 2
# args.n_qubits = 2
# args.template = 'NQubitPQC'

In [399]:
data = ElectronPhoton(args)
data.prepare_data()
data.setup()
print(data)

Center cropping...
Center cropping...
Performing power transform...
Standardizing data...
Converting labels to categorical...
Converting labels to categorical...

Dataset :Electron Photon med
╒════════╤═════════════════╤═════════════════╤═════════════════╤═══════════╕
│ Data   │ Train size      │ Val size        │ Test size       │ Dims      │
╞════════╪═════════════════╪═════════════════╪═════════════════╪═══════════╡
│ X      │ (7200, 8, 8, 1) │ (1800, 8, 8, 1) │ (1000, 8, 8, 1) │ (8, 8, 1) │
├────────┼─────────────────┼─────────────────┼─────────────────┼───────────┤
│ y      │ (7200, 2)       │ (1800, 2)       │ (1000, 2)       │ (2,)      │
╘════════╧═════════════════╧═════════════════╧═════════════════╧═══════════╛

╒══════════════╤═══════╤═══════╤════════╤═══════╤══════════════════════════╕
│ Type         │   Min │   Max │   Mean │   Std │ Samples for each class   │
╞══════════════╪═══════╪═══════╪════════╪═══════╪══════════════════════════╡
│ Train Images │ -3.05 │  5.53 │   0 

In [400]:
if args.wandb:
     wandb.init(project='qml-hep-lhc')

## Hyperparameters

In [401]:
input_dims = data.config()['input_dims']

In [419]:
def get_out_shape(in_shape, k, s, padding):
    in_shape = (1,) + in_shape
    a = np.random.uniform(size = (in_shape))
    dn = jax.lax.conv_dimension_numbers(a.shape, (1,1,k[0],k[1]), ('NHWC', 'IOHW', 'NHWC'))
    out = jax.lax.conv_general_dilated_patches(lhs = a,
                                           filter_shape= k,
                                           window_strides=s,
                                           padding=padding,
                                           dimension_numbers=dn 
                                    )
    return out.shape

In [440]:
# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-1):
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]
 

kernel_size = (3,3)
strides = (1,1)
padding = "VALID"
clayer_sizes = [8,2]

step_size = 0.2
num_epochs = 30

n_layers = 2
n_qubits = 4


conv_out_shape = get_out_shape(input_dims, kernel_size, strides, padding)
num_pixels = np.prod(conv_out_shape[:-1])*n_qubits
qlayer_sizes = [np.prod(kernel_size), n_layers*n_qubits*3]
clayer_sizes = [num_pixels] + clayer_sizes

params = init_network_params(qlayer_sizes, random.PRNGKey(0))
params += init_network_params(clayer_sizes, random.PRNGKey(0))

  rhs = jnp.eye(spatial_size, dtype=lhs.dtype).reshape(filter_shape * 2)


In [441]:
for i in params:
    print(i[0].shape, i[1].shape)

(24, 9) (24,)
(8, 144) (8,)
(2, 8) (2,)


## QLayers

In [442]:
dev = qml.device('default.qubit.jax', wires=n_qubits)
qubits =list(range(n_qubits))

@jax.jit
@qml.qnode(dev, interface='jax')
def circuit(inputs, w, b):
    z = jnp.dot(inputs, jnp.transpose(w)) + b
    z = jnp.tanh(z)*jnp.pi/2
    z = jnp.reshape(z, (n_layers, n_qubits, 3, -1))
    
    for q in qubits:
        qml.Hadamard(wires=q)
    
    for l in range(n_layers):
        for q in qubits:
            qml.Rot(z[l,q,0], z[l,q,1], z[l,q,2], wires= q)
        if (l & 1):
            for q0, q1 in zip(qubits[1::2], qubits[2::2] + [qubits[0]]):
                qml.CZ((q0,q1))
        else:
            for q0, q1 in zip(qubits[0::2], qubits[1::2]):
                qml.CZ((q0,q1))
   
    return [qml.expval(qml.PauliX(q)@qml.PauliY(q)@qml.PauliZ(q)) for q in qubits]

In [443]:
def qconv(x, w, b):
    x = jnp.expand_dims(x,axis=0)
        
    dn = jax.lax.conv_dimension_numbers(x.shape, 
                                        (1,1,kernel_size[0],kernel_size[1]), 
                                        ('NHWC', 'IOHW', 'NHWC'))
    x = jax.lax.conv_general_dilated_patches(lhs = x,
                                               filter_shape= kernel_size,
                                               window_strides=strides,
                                               padding=padding,
                                               dimension_numbers=dn 
                                              )
    iters = x.shape[1:3]
    x = jnp.reshape(x, (-1, np.prod(kernel_size)))
    x = circuit(x, w, b)
    x = jnp.reshape(x, iters + (n_qubits,))
    return x

In [444]:
dev = qml.device("default.qubit", wires=n_qubits)
qnode = qml.QNode(circuit, dev)

inputs = np.random.uniform(size = (np.prod(kernel_size),))
w,b  = params[0]
drawer = qml.draw(qnode, expansion_strategy="device")
print(drawer(inputs,w,b))

0: ──H──Rot─╭●──Rot────╭Z─┤  <X@Y@Z>
1: ──H──Rot─╰Z──Rot─╭●─│──┤  <X@Y@Z>
2: ──H──Rot─╭●──Rot─╰Z─│──┤  <X@Y@Z>
3: ──H──Rot─╰Z──Rot────╰●─┤  <X@Y@Z>


## Auto-Batching Predictions

In [445]:
from jax.scipy.special import logsumexp

def relu(x):
    return jnp.maximum(0, x)

def forward(params, image):
  # per-example predictions

    qw, qb = params[0]
    activations = qconv(image, qw, qb)
    activations = jnp.reshape(activations, (-1))
    
    for w, b in params[1:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)

In [446]:
# This works on single examples
random_flattened_image = random.normal(random.PRNGKey(1), input_dims)
preds = forward(params,  random_flattened_image)
print(preds)

[-0.71354556 -0.6731566 ]


In [447]:
# Doesn't work with a batch
random_flattened_images = random.normal(random.PRNGKey(1), (2,)+ input_dims)
# try:
#     preds = predict(params, random_flattened_images)
# except TypeError:
#     print('Invalid shapes!')

In [448]:
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_forward = vmap(forward, in_axes=(None,0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_forward(params, random_flattened_images)
print(batched_preds)

[[-0.7196192 -0.6673578]
 [-0.7208551 -0.6661863]]


## Utility and loss functions

In [449]:
def accuracy(y_true, y_pred):
    target_class = jnp.argmax(y_true, axis=1)
    predicted_class = jnp.argmax(y_pred, axis=1)
    return jnp.mean(predicted_class == target_class)


def loss_fn(params, images, targets):
    preds = batched_forward(params, images)
    loss_value = jnp.mean(optax.softmax_cross_entropy(logits=preds, labels=targets))
    return loss_value, preds

@jit
def update(opt_state, params, x, y):
    (loss_value, preds), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state


def step(params,x,y):
    loss_value, preds = loss_fn(params, x, y)
    acc = accuracy(y, preds)
    return loss, acc

def evaluate(params, ds):
    res = jnp.array([step(params, x,y)  for x,y in tfds.as_numpy(ds)])
    avg_loss = jnp.mean(res[:,0])
    avg_acc = jnp.mean(res[:,1])
    return avg_loss, avg_acc

## Training loop

In [452]:
schedule_fn = optax.cosine_onecycle_schedule(transition_steps=100,
                                            peak_value=step_size,
                                            div_factor=5,
                                            )
# Defining an optimizer in Jax
optimizer = optax.adam(learning_rate=0.2)
opt_state = optimizer.init(params)

In [None]:
import time
num_epochs = 20

for epoch in range(num_epochs):
    start_time = time.time()
    with tqdm(tfds.as_numpy(data.train_ds), unit="batch") as tepoch:
        for x, y in tepoch:
            tepoch.set_description(f"Epoch {epoch}")
            params, opt_state = update(opt_state, params, x, y)
        
        epoch_time = time.time() - start_time
        
        train_loss, train_acc = evaluate(params, data.train_ds)
        val_loss, val_acc = evaluate(params, data.val_ds)
        print('loss: {} - acc: {} - val_loss: {} - val_acc: {}'.format(train_loss,train_acc, val_loss, val_acc))
    
    if args.wandb:
        wandb.log({"accuracy": train_acc, "val_accuracy": test_acc, 'loss':train_loss, 'val_loss':val_loss})


Epoch 0: 100%|████████████████████████████████████████████████████████████████████| 57/57 [00:48<00:00,  1.18batch/s]
Epoch 1:  23%|███████████████▎                                                   | 13/57 [00:00<00:00, 119.89batch/s]

loss: 0.669510006904602 - acc: 0.499862939119339 - val_loss: 0.6695100665092468 - val_acc: 0.5078125


Epoch 1: 100%|███████████████████████████████████████████████████████████████████| 57/57 [00:00<00:00, 130.40batch/s]
Epoch 2:  23%|███████████████▎                                                   | 13/57 [00:00<00:00, 120.48batch/s]

loss: 0.669510006904602 - acc: 0.49958881735801697 - val_loss: 0.6695100665092468 - val_acc: 0.499479204416275


Epoch 2: 100%|███████████████████████████████████████████████████████████████████| 57/57 [00:00<00:00, 133.17batch/s]
Epoch 3:  21%|██████████████                                                     | 12/57 [00:00<00:00, 116.83batch/s]

loss: 0.669510006904602 - acc: 0.499725878238678 - val_loss: 0.6695100665092468 - val_acc: 0.5


Epoch 3: 100%|███████████████████████████████████████████████████████████████████| 57/57 [00:00<00:00, 132.54batch/s]
Epoch 4:  19%|████████████▉                                                      | 11/57 [00:00<00:00, 105.03batch/s]

loss: 0.669510006904602 - acc: 0.5004112124443054 - val_loss: 0.6695100665092468 - val_acc: 0.49270835518836975


Epoch 4: 100%|███████████████████████████████████████████████████████████████████| 57/57 [00:00<00:00, 119.19batch/s]
Epoch 5:  21%|██████████████                                                     | 12/57 [00:00<00:00, 113.90batch/s]

loss: 0.669510006904602 - acc: 0.5008223652839661 - val_loss: 0.6695100665092468 - val_acc: 0.5


Epoch 5: 100%|███████████████████████████████████████████████████████████████████| 57/57 [00:00<00:00, 127.10batch/s]
Epoch 6:  21%|██████████████                                                     | 12/57 [00:00<00:00, 113.87batch/s]

loss: 0.669510006904602 - acc: 0.5 - val_loss: 0.6695100665092468 - val_acc: 0.4921875298023224


Epoch 6: 100%|███████████████████████████████████████████████████████████████████| 57/57 [00:00<00:00, 114.90batch/s]
Epoch 7:  23%|███████████████▎                                                   | 13/57 [00:00<00:00, 127.54batch/s]

loss: 0.669510006904602 - acc: 0.4987664520740509 - val_loss: 0.6695100665092468 - val_acc: 0.5


Epoch 7: 100%|███████████████████████████████████████████████████████████████████| 57/57 [00:00<00:00, 138.83batch/s]
Epoch 8:  23%|███████████████▎                                                   | 13/57 [00:00<00:00, 127.30batch/s]

loss: 0.669510006904602 - acc: 0.49958881735801697 - val_loss: 0.6695100665092468 - val_acc: 0.5234375


Epoch 8: 100%|███████████████████████████████████████████████████████████████████| 57/57 [00:00<00:00, 128.14batch/s]
Epoch 9:  23%|███████████████▎                                                   | 13/57 [00:00<00:00, 126.17batch/s]

loss: 0.669510006904602 - acc: 0.5004112124443054 - val_loss: 0.6695100665092468 - val_acc: 0.4843750298023224


Epoch 9: 100%|████████████████████████████████████████████████████████████████████| 57/57 [00:00<00:00, 90.18batch/s]
Epoch 10:  23%|███████████████                                                   | 13/57 [00:00<00:00, 126.67batch/s]

loss: 0.669510006904602 - acc: 0.49958881735801697 - val_loss: 0.6695100665092468 - val_acc: 0.5078125


Epoch 10: 100%|██████████████████████████████████████████████████████████████████| 57/57 [00:00<00:00, 140.93batch/s]


In [415]:
evaluate(params, data.test_ds)

(DeviceArray(0.66951007, dtype=float32),
 DeviceArray(0.48069412, dtype=float32))

In [55]:
if args.wandb:
    wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▆██████▇█████▁███▅▅█
val_accuracy,▆██████▇█████▁███▅▅█

0,1
accuracy,0.9544
val_accuracy,0.9546


In [439]:
for i in range(200):
    print(i, schedule_fn(i))

0 0.03999999
1 0.04043825
2 0.04174818
3 0.04391548
4 0.046916366
5 0.05071795
6 0.05527863
7 0.06054841
8 0.06646955
9 0.07297717
10 0.07999999
11 0.08746106
12 0.095278636
13 0.10336706
14 0.11163772
15 0.12
16 0.12836227
17 0.13663292
18 0.14472136
19 0.15253893
20 0.16
21 0.1670228
22 0.17353044
23 0.17945157
24 0.18472135
25 0.18928201
26 0.19308363
27 0.1960845
28 0.1982518
29 0.19956174
30 0.19999999
31 0.1998993
32 0.19959743
33 0.19909498
34 0.19839299
35 0.19749282
36 0.19639634
37 0.19510573
38 0.19362362
39 0.19195293
40 0.19009706
41 0.18805978
42 0.18584515
43 0.18345764
44 0.18090206
45 0.17818357
46 0.17530763
47 0.17228001
48 0.16910687
49 0.16579455
50 0.16234972
51 0.15877934
52 0.15509057
53 0.15129088
54 0.1473879
55 0.1433895
56 0.13930371
57 0.13513878
58 0.13090307
59 0.12660514
60 0.122253634
61 0.11785732
62 0.11342505
63 0.10896575
64 0.10448839
65 0.10000199
66 0.09551559
67 0.09103824
68 0.08657893
69 0.08214667
70 0.07775034
71 0.07339884
72 0.0691009
73 0