In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

from importlib.util import find_spec
if find_spec("qml_hep_lhc") is None:
    import sys
    sys.path.append('..')

In [199]:
from qml_hep_lhc.data import ElectronPhoton, MNIST, QuarkGluon
from qml_hep_lhc.data.utils import tf_ds_to_numpy
import argparse
import wandb

import pennylane as qml
import jax.numpy as jnp
import jax
import optax
from jax.nn.initializers import he_uniform
from jax import grad, jit, vmap
from jax import random
import tensorflow_datasets as tfds
from tqdm import tqdm
import numpy as np
import tensorflow as tf

# Added to silence some warnings.
# from jax.config import config
# config.update("jax_enable_x64", True)

from qiskit.circuit.library import EfficientSU2, PauliFeatureMap
from qiskit import QuantumCircuit
from qiskit.circuit import ParameterVector, Parameter

import matplotlib.pyplot as plt
import time

In [8]:
jax.devices()

[CpuDevice(id=0)]

In [316]:
args = argparse.Namespace()

# Data
args.center_crop = 0.2
# args.resize = [8,8]
args.standardize = 1
# args.power_transform = 1
# args.binary_data = [3,6]
# args.percent_samples = 0.01
# args.processed = 1
args.dataset_type = '1'
args.labels_to_categorical = 1
args.batch_size = 128
args.validation_split = 0.1

# Base Model
args.wandb = False
args.epochs = 10
args.learning_rate = 0.005

# Quantum CNN Parameters
args.n_layers = 1
args.n_elayers = 1
args.n_qubits = 9
args.template = 'NQubitPQCSparse'
args.initializer = 'he_uniform'

args.kernel_size = (3,3)
args.strides = (1,1)
args.padding = "SAME"

args.clayer_sizes = [8, 2]

In [317]:
if args.wandb:
     wandb.init(project='qml-hep-lhc', config = vars(args))

In [318]:
data = ElectronPhoton(args)
data.prepare_data()
data.setup()
print(data)

Center cropping...
Center cropping...
Standardizing data...
Converting labels to categorical...
Converting labels to categorical...

Dataset :Electron Photon 1
╒════════╤═════════════════╤════════════════╤═════════════════╤═══════════╕
│ Data   │ Train size      │ Val size       │ Test size       │ Dims      │
╞════════╪═════════════════╪════════════════╪═════════════════╪═══════════╡
│ X      │ (8100, 8, 8, 1) │ (900, 8, 8, 1) │ (1000, 8, 8, 1) │ (8, 8, 1) │
├────────┼─────────────────┼────────────────┼─────────────────┼───────────┤
│ y      │ (8100, 2)       │ (900, 2)       │ (1000, 2)       │ (2,)      │
╘════════╧═════════════════╧════════════════╧═════════════════╧═══════════╛

╒══════════════╤═══════╤═══════╤════════╤═══════╤══════════════════════════╕
│ Type         │   Min │   Max │   Mean │   Std │ Samples for each class   │
╞══════════════╪═══════╪═══════╪════════╪═══════╪══════════════════════════╡
│ Train Images │ -2.88 │ 59.59 │   0    │  1.01 │ [4050, 4050]             │

## Hyperparameters

In [319]:
input_dims = data.config()['input_dims']

In [320]:
def get_out_shape(in_shape, k, s, padding):
    in_shape = (1,) + in_shape
    a = np.random.uniform(size = (in_shape))
    dn = jax.lax.conv_dimension_numbers(a.shape, (1,1,k[0],k[1]), ('NHWC', 'IOHW', 'NHWC'))
    out = jax.lax.conv_general_dilated_patches(lhs = a,
                                           filter_shape= k,
                                           window_strides=s,
                                           padding=padding,
                                           dimension_numbers=dn 
                                    )
    return out.shape

In [324]:
initializer = he_uniform()

# Get qlayer sizes
def get_qlayer_sizes(template, n_l, n_q, k_size):
    if template == 'NQubitPQCSparse':
        return {
            'w': (n_qubits, 4),
            'b':(n_qubits,1)
        }

def random_qlayer_params(size, key, scale=1e-1):
    return initializer(key, size)
    return scale * random.normal(key, size)

def init_qnetwork_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [[random_qlayer_params(size, key) for size, key in zip(sizes.values(), keys)]]
 

# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_clayer_params(m, n, key, scale=1e-1):
    w_key, b_key = random.split(key)
    return initializer(w_key, (n,m)), random.normal(b_key, (n,))
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [random_clayer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

kernel_size = args.kernel_size
strides = args.strides
padding = args.padding
clayer_sizes = args.clayer_sizes

template = args.template
n_layers = args.n_layers
n_qubits = args.n_qubits


conv_out_shape = get_out_shape(input_dims, kernel_size, strides, padding)
num_pixels = np.prod(conv_out_shape[:-1])
qlayer_sizes = get_qlayer_sizes(template, n_layers, n_qubits, kernel_size)
clayer_sizes = [num_pixels] + clayer_sizes

params = []
params += init_qnetwork_params(qlayer_sizes, random.PRNGKey(0))
# params += init_qnetwork_params(qlayer_sizes, random.PRNGKey(1))
params += init_network_params(clayer_sizes, random.PRNGKey(2))

In [325]:
for i in params:
    for j in i:
        print(j.shape, end = ' ')
    print()

(9, 4) (9, 1) 
(8, 64) (8,) 
(2, 8) (2,) 


## QLayers

In [362]:
dev = qml.device('default.qubit.jax', wires=n_qubits+1)
qubits =list(range(n_qubits))

@jax.jit
@qml.qnode(dev, interface='jax')
def NQubitPQCSparse(inputs, w, b):
    inputs = jnp.transpose(inputs)
    
    # Pauli feature map
    for l in range(args.n_elayers):
        for i in qubits:
            qml.Hadamard(i)
            qml.RZ(2*inputs[i,:],wires=i)
    
    # Efficient SU2
    for l in range(n_layers):
        for i in qubits:
            qml.RY(w[i,0], wires=i)
            qml.RZ(w[i,1], wires=i)
        
        # circular entanglement
        for p, q in zip(qubits, qubits[1:]):
            qml.CNOT((p,q))
        if len(qubits) > 2:
            qml.CNOT((qubits[-1], qubits[0]))

        for i in qubits:
            qml.RY(w[i,2], wires=i)
            qml.RZ(w[i,3], wires=i)

    # Readout 
    for i in qubits:
        qml.IsingXX(phi=b[i,0], wires = (i,n_qubits))
    
    return [qml.expval(qml.PauliZ(n_qubits))]

In [363]:
def get_node(template):
    if template == 'NQubitPQC':
        return NQubitPQC
    elif template == 'SimpleDRC':
        return SimpleDRC
    elif template == 'NQubitPQCSparse':
        return NQubitPQCSparse

In [364]:
def qconv(x, *qweights):
    x = jnp.expand_dims(x,axis=0)
    dn = jax.lax.conv_dimension_numbers(x.shape, 
                                        (1,1,kernel_size[0],kernel_size[1]), 
                                        ('NHWC', 'IOHW', 'NHWC'))
    x = jax.lax.conv_general_dilated_patches(lhs = x,
                                               filter_shape= kernel_size,
                                               window_strides=strides,
                                               padding=padding,
                                               dimension_numbers=dn 
                                              )
    iters = x.shape[1:3]
    x = jnp.reshape(x, (-1, np.prod(kernel_size)))
    
    x = get_node(template)(x, *qweights)
    x = jnp.reshape(x, iters + (1,))
    return x

In [365]:
dev = qml.device("default.qubit", wires=n_qubits)
qnode = qml.QNode(get_node(template), dev)

inputs = np.random.uniform(size = (10, np.prod(kernel_size)))
weights = params[0]
drawer = qml.draw(qnode, expansion_strategy="device")
print(drawer(inputs,*weights))

0: ──H──RZ(M7)──RY(-0.48)──RZ(0.25)──╭●──────────────────────╭X──RY(-0.26)──RZ(0.25)─
1: ──H──RZ(M5)──RY(0.34)───RZ(-0.29)─╰X─╭●───────────────────│───RY(0.45)───RZ(0.27)─
2: ──H──RZ(M8)──RY(-0.56)──RZ(0.45)─────╰X─╭●────────────────│───RY(0.27)───RZ(-0.47)
3: ──H──RZ(M1)──RY(-0.61)──RZ(0.45)────────╰X─╭●─────────────│───RY(0.12)───RZ(0.74)─
4: ──H──RZ(M3)──RY(0.56)───RZ(0.40)───────────╰X─╭●──────────│───RY(0.03)───RZ(-0.45)
5: ──H──RZ(M0)──RY(0.37)───RZ(-0.47)─────────────╰X─╭●───────│───RY(0.41)───RZ(-0.11)
6: ──H──RZ(M2)──RY(0.75)───RZ(0.13)─────────────────╰X─╭●────│───RY(0.80)───RZ(-0.74)
7: ──H──RZ(M6)──RY(-0.79)──RZ(-0.46)───────────────────╰X─╭●─│───RY(-0.20)──RZ(-0.45)
8: ──H──RZ(M4)──RY(0.46)───RZ(0.24)───────────────────────╰X─╰●──RY(0.68)───RZ(-0.01)
9: ──────────────────────────────────────────────────────────────────────────────────

──╭IsingXX(-0.25)──────────────────────────────────────────────────────────────────────────────
──│───────────────╭IsingXX(-0.03)──────────

## Auto-Batching Predictions

In [366]:
from jax.scipy.special import logsumexp

def relu(x):
    return jnp.maximum(0, x)

def forward(params, image):
  # per-example predictions
    activations = qconv(image, *params[0])
    activations += image
    activations = relu(activations)
    activations = jnp.reshape(activations, (-1))
    for w, b in params[1:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)

In [367]:
def forwardx(params, image):
  # per-example predictions
    activations = qconv(image, *params[0])
    activations = qconv(activations, *params[1])
    activations += image
    activations = relu(activations)
    activations = jnp.reshape(activations, (-1))
    for w, b in params[2:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)

In [368]:
# This works on single examples
random_flattened_image = random.normal(random.PRNGKey(1), input_dims)
random_flattened_image = jnp.floor(random_flattened_image*10)
preds = forward(params,  random_flattened_image)
print(preds)

[   0.      -155.40536]


In [369]:
# Doesn't work with a batch
random_flattened_images = random.normal(random.PRNGKey(1), (2,)+ input_dims)
random_flattened_images = jnp.floor(random_flattened_images*10)
# try:
#     preds = predict(params, random_flattened_images)
# except TypeError:
#     print('Invalid shapes!')

In [370]:
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_forward = vmap(forward, in_axes=(None,0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_forward(params, random_flattened_images)
print(batched_preds)

[[   0.       -165.11536 ]
 [   0.        -44.895844]]


## Utility and loss functions

In [371]:
from sklearn.metrics import roc_auc_score

def accuracy(y_true, y_pred):
    target_class = jnp.argmax(y_true, axis=1)
    predicted_class = jnp.argmax(y_pred, axis=1)
    return jnp.mean(predicted_class == target_class)
 

def loss_fn(params, images, targets):
    preds = batched_forward(params, images)
    loss_value = -jnp.mean(preds * targets)
    return loss_value, preds

@jit
def update(opt_state, params, x, y):
    (loss_value, preds), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, x, y)
    acc = accuracy(y,preds)
    
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    
    return params, opt_state, loss_value, acc 


def step(params,x,y):
    loss_value, preds = loss_fn(params, x, y)
    acc = accuracy(y, preds)
    return loss_value, acc

def evaluate(params, ds):
    losses = []
    accs = []
    with tqdm(tfds.as_numpy(ds), unit="batch") as tepoch:
        for x, y in tepoch:
            loss_value, acc = step(params, x, y)
            losses.append(loss_value)
            accs.append(acc)
            tepoch.set_postfix(loss=loss_value, acc=acc)
            
    return jnp.mean(np.array(losses)), jnp.mean(np.array(accs))

def predict(params, ds):
    preds = []
    y_true = []
    with tqdm(tfds.as_numpy(ds), unit="batch") as tepoch:
        for x, y in tepoch:
            preds += list(batched_forward(params, x))
            y_true += list(y)
    
    return np.array(preds), np.array(y_true)

## Training loop

In [372]:
lr = 1e-3

In [373]:
schedule_fn = optax.linear_schedule(transition_steps=150,
                                    init_value=0.2,
                                    end_value=1e-7,
                                    )
# Defining an optimizer in Jax 
# optimizer = optax.adam(learning_rate=schedule_fn)

print(lr)
optimizer = optax.adam(learning_rate=args.learning_rate)
# optimizer = optax.adam(learning_rate=lr)
opt_state = optimizer.init(params)
lr = (lr*0.1)

0.001


In [374]:
import time

epochs = args.epochs
# epochs = 5
for epoch in range(20):
    start_time = time.time()

    with tqdm(tfds.as_numpy(data.train_ds), unit="batch") as tepoch:
        for x, y in tepoch:
            tepoch.set_description(f"Epoch {epoch}")
            params, opt_state, loss_value, acc = update(opt_state, params, x, y)
            tepoch.set_postfix(loss=loss_value, acc=acc)
        
    epoch_time = time.time() - start_time

    val_loss, val_acc = evaluate(params, data.val_ds)
    print('val_loss: {} - val_acc: {}-  time: {}'.format(val_loss, val_acc, epoch_time))
    
    if args.wandb:
        wandb.log({"accuracy": acc, 
                   "val_accuracy": val_acc, 
                   'loss':loss_value, 
                   'val_loss':val_loss})


Epoch 0:   0%|                                                                             | 0/64 [00:54<?, ?batch/s]


KeyboardInterrupt: 

In [205]:
test_loss, test_acc = evaluate(params, data.test_ds)
test_loss, test_acc

100%|█████████████████████████████████████████████████| 766/766 [00:14<00:00, 53.98batch/s, acc=0.7, loss=0.29832602]


(DeviceArray(0.30152068, dtype=float32), DeviceArray(0.6875367, dtype=float32))

In [206]:
from sklearn.metrics import roc_auc_score

out,y_test = predict(params, data.test_ds)
# _, y_test = tf_ds_to_numpy(data.test_ds)
test_auc = roc_auc_score(y_test, out)
test_auc

100%|███████████████████████████████████████████████████████████████████████████| 766/766 [00:16<00:00, 46.29batch/s]


0.7415369426280717

In [207]:
if args.wandb:
    wandb.run.summary['test_loss'] = test_loss
    wandb.run.summary['test_acc'] = test_acc
    wandb.run.summary['test_auc'] = test_auc
    y = y_test.argmax(axis=1)
    preds = out.argmax(axis=1)
    probs = out
    classes = data.mapping

    roc_curve = wandb.sklearn.plot_roc(y, probs, classes)
    confusion_matrix = wandb.sklearn.plot_confusion_matrix(y, preds, classes)

    wandb.log({"roc_curve": roc_curve})
    wandb.log({"confusion_matrix": confusion_matrix})



In [208]:
if args.wandb:
    wandb.finish()

VBox(children=(Label(value='0.360 MB of 0.360 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁▆▅▄▆▆▆▄▇▅▇▆█▆▃▆▄▇▅▇▄▆▆▇▇▇▅▇▇▆▇▄▄▇▅▆▆▅▆▆
loss,█▅▅▇▃▄▄▆▄▅▃▃▁▃▅▄▅▂▆▃▅▄▅▄▂▃▅▂▁▃▂▆▄▃▅▄▃▄▃▂
val_accuracy,▁▅▆▇▇▇▇▇▇▇█▇████████████████████████████
val_loss,█▄▃▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
accuracy,0.66667
loss,0.30837
test_acc,0.68754
test_auc,0.74154
test_loss,0.30152
val_accuracy,0.68685
val_loss,0.3013


In [32]:
for i in range(200):
    print(i, schedule_fn(i))

0 0.2
1 0.19866668
2 0.19733334
3 0.19600001
4 0.19466668
5 0.19333333
6 0.192
7 0.19066668
8 0.18933333
9 0.18800001
10 0.18666668
11 0.18533334
12 0.18400002
13 0.18266669
14 0.18133333
15 0.18
16 0.17866668
17 0.17733334
18 0.17600001
19 0.17466669
20 0.17333335
21 0.17200002
22 0.1706667
23 0.16933335
24 0.16800003
25 0.16666669
26 0.16533335
27 0.16400002
28 0.1626667
29 0.16133335
30 0.16000003
31 0.15866669
32 0.15733334
33 0.15600002
34 0.15466669
35 0.15333335
36 0.15200002
37 0.1506667
38 0.14933336
39 0.14800003
40 0.1466667
41 0.14533336
42 0.14400004
43 0.14266671
44 0.14133336
45 0.14000003
46 0.1386667
47 0.13733336
48 0.13600004
49 0.1346667
50 0.13333336
51 0.13200003
52 0.1306667
53 0.12933336
54 0.12800004
55 0.12666671
56 0.12533337
57 0.124000035
58 0.1226667
59 0.121333376
60 0.12000004
61 0.11866671
62 0.11733337
63 0.116000034
64 0.1146667
65 0.113333374
66 0.11200004
67 0.1106667
68 0.109333366
69 0.10800003
70 0.10666671
71 0.10533337
72 0.10400004
73 0.102666