In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

from importlib.util import find_spec
if find_spec("qml_hep_lhc") is None:
    import sys
    sys.path.append('..')

In [2]:
from qml_hep_lhc.data import ElectronPhoton, MNIST, QuarkGluon
from qml_hep_lhc.data.utils import tf_ds_to_numpy
import argparse
import wandb

import pennylane as qml
import jax.numpy as jnp
import jax
import optax
from jax.nn.initializers import he_uniform
from jax import grad, jit, vmap
from jax import random
import tensorflow_datasets as tfds
from tqdm import tqdm
import numpy as np
import tensorflow as tf

# Added to silence some warnings.
# from jax.config import config
# config.update("jax_enable_x64", True)

import matplotlib.pyplot as plt
import time

In [3]:
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [100]:
args = argparse.Namespace()

# Data
# args.center_crop = 0.2
# args.resize = [8,8]
args.standardize = 1
args.power_transform = 1
# args.binary_data = [3,6]
# args.percent_samples = 0.01
# args.processed = 1
args.dataset_type = '2'
args.labels_to_categorical = 1
args.batch_size = 128
args.validation_split = 0.05

# Base Model
args.wandb = True
args.epochs = 20
args.learning_rate = 0.001

# Quantum CNN Parameters
args.n_layers = 2
args.n_qubits = 4
args.template = 'NQubitPQCSparse'
args.initializer = 'he_uniform'

args.kernel_size = (3,3)
args.strides = (2,2)
args.padding = "VALID"

args.clayer_sizes = [8, 2]

In [101]:
if args.wandb:
     wandb.init(project='qml-hep-lhc', config = vars(args))

In [76]:
data = QuarkGluon(args)
data.prepare_data()
data.setup()
print(data)

Performing power transform...
Standardizing data...
Converting labels to categorical...
Converting labels to categorical...

Dataset :Quark Gluon 2
╒════════╤════════════════════╤═══════════════════╤════════════════════╤═════════════╕
│ Data   │ Train size         │ Val size          │ Test size          │ Dims        │
╞════════╪════════════════════╪═══════════════════╪════════════════════╪═════════════╡
│ X      │ (95000, 40, 40, 1) │ (5000, 40, 40, 1) │ (20000, 40, 40, 1) │ (40, 40, 1) │
├────────┼────────────────────┼───────────────────┼────────────────────┼─────────────┤
│ y      │ (95000, 2)         │ (5000, 2)         │ (20000, 2)         │ (2,)        │
╘════════╧════════════════════╧═══════════════════╧════════════════════╧═════════════╛

╒══════════════╤═══════╤═══════╤════════╤═══════╤══════════════════════════╕
│ Type         │   Min │   Max │   Mean │   Std │ Samples for each class   │
╞══════════════╪═══════╪═══════╪════════╪═══════╪══════════════════════════╡
│ Train Ima

## Hyperparameters

In [102]:
input_dims = data.config()['input_dims']

In [103]:
def get_out_shape(in_shape, k, s, padding):
    in_shape = (1,) + in_shape
    a = np.random.uniform(size = (in_shape))
    dn = jax.lax.conv_dimension_numbers(a.shape, (1,1,k[0],k[1]), ('NHWC', 'IOHW', 'NHWC'))
    out = jax.lax.conv_general_dilated_patches(lhs = a,
                                           filter_shape= k,
                                           window_strides=s,
                                           padding=padding,
                                           dimension_numbers=dn 
                                    )
    return out.shape

In [104]:
initializer = he_uniform()

# Get qlayer sizes
def get_qlayer_sizes(template, n_l, n_q, k_size):
    if template == 'NQubitPQCSparse':
        return {
            'w': (n_l, n_q,3,np.prod(k_size)),
            'b': (n_l,n_q,3,1)
        }
    elif template == 'SimpleDRC':
        return {
            'w': (n_l+1,n_q,3),
            's': (n_l,n_q),
            'b': (n_l,n_q)
        }
    elif template == 'NQubitPQC':
        assert np.prod(k_size)%3 == 0
        return {
            'w': (n_l,n_q,np.prod(k_size)),
            'b': (n_l,n_q,np.prod(k_size))
        }

def random_qlayer_params(size, key, scale=1e-1):
    return initializer(key, size)
    return scale * random.normal(key, size)

def init_qnetwork_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [[random_qlayer_params(size, key) for size, key in zip(sizes.values(), keys)]]
 

# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_clayer_params(m, n, key, scale=1e-1):
    w_key, b_key = random.split(key)
    return initializer(w_key, (n,m)), random.normal(b_key, (n,))
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [random_clayer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

kernel_size = args.kernel_size
strides = args.strides
padding = args.padding
clayer_sizes = args.clayer_sizes

template = args.template
n_layers = args.n_layers
n_qubits = args.n_qubits


conv_out_shape = get_out_shape(input_dims, kernel_size, strides, padding)
num_pixels = np.prod(conv_out_shape[:-1])*n_qubits
qlayer_sizes = get_qlayer_sizes(template, n_layers, n_qubits, kernel_size)
clayer_sizes = [num_pixels] + clayer_sizes

params = []
params += init_qnetwork_params(qlayer_sizes, random.PRNGKey(0))
# params += init_qnetwork_params(qlayer_sizes, random.PRNGKey(1))
params += init_network_params(clayer_sizes, random.PRNGKey(2))

  rhs = jnp.eye(spatial_size, dtype=lhs.dtype).reshape(filter_shape * 2)


In [105]:
for i in params:
    for j in i:
        print(j.shape, end = ' ')
    print()

(2, 4, 3, 9) (2, 4, 3, 1) 
(8, 1444) (8,) 
(2, 8) (2,) 


## QLayers

In [106]:
dev = qml.device('default.qubit.jax', wires=n_qubits)
qubits =list(range(n_qubits))

@jax.jit
@qml.qnode(dev, interface='jax')
def NQubitPQCSparse(inputs, w, b):
    z = jnp.dot(w, jnp.transpose(inputs))+ b

    for q in qubits:
        qml.Hadamard(wires=q)
    
    for l in range(n_layers):
        for q in qubits:
            qml.Rot(z[l,q,0], z[l,q,1], z[l,q,2], wires= q)
        if (l & 1):
            for q0, q1 in zip(qubits[1::2], qubits[2::2] + [qubits[0]]):
                qml.CZ((q0,q1))
        else:
            for q0, q1 in zip(qubits[0::2], qubits[1::2]):
                qml.CZ((q0,q1))
   
#     return qml.expval(qml.PauliZ(qubits[-1]))
    return [qml.expval(qml.PauliZ(q)) for q in qubits]

In [107]:
dev = qml.device('default.qubit.jax', wires=n_qubits)
qubits =list(range(n_qubits))

@jax.jit
@qml.qnode(dev, interface='jax')
def NQubitPQC(inputs, w, b):
    steps = inputs.shape[-1]//3
    for q in qubits:
        qml.Hadamard(wires=q)
    
    for l in range(n_layers):
        for q in qubits:
            for i in range(steps):
                z = jnp.transpose(jnp.multiply(inputs[:,3*i:3*i+3],w[l,q,3*i:3*i+3]) + b[l,q,3*i:3*i+3])
                qml.RX(z[0], wires=q)
                qml.RY(z[1], wires=q)
                qml.RZ(z[2], wires=q)
                
        if (l & 1):
            for q0, q1 in zip(qubits[1::2], qubits[2::2] + [qubits[0]]):
                qml.CZ((q0,q1))
        else:
            for q0, q1 in zip(qubits[0::2], qubits[1::2]):
                qml.CZ((q0,q1))

#     return qml.expval(qml.PauliZ(qubits[-1]))
    return [qml.expval(qml.PauliZ(q)) for q in qubits]

In [108]:
dev = qml.device('default.qubit.jax', wires=n_qubits)
qubits =list(range(n_qubits))

@jax.jit
@qml.qnode(dev, interface='jax')
def SimpleDRC(inputs, w, s, b):
    for l in range(n_layers):
        x = jnp.transpose(jnp.multiply(s[l],inputs) + b[l]) 
        for q in qubits:
            qml.Rot(*w[l,q], wires = q)
        for q0, q1 in zip(qubits, qubits[1:]):
            qml.CZ((q0, q1))
        if len(qubits) != 2:
            qml.CZ((qubits[0], qubits[-1]))
        for q in qubits:
            qml.RX(x[q], wires=q)
    for q in qubits:
            qml.Rot(*w[n_layers,q], wires = q)
   
    return [qml.expval(qml.PauliZ(q)) for q in qubits]

In [109]:
def get_node(template):
    if template == 'NQubitPQC':
        return NQubitPQC
    elif template == 'SimpleDRC':
        return SimpleDRC
    elif template == 'NQubitPQCSparse':
        return NQubitPQCSparse

In [110]:
def qconv(x, qweights):
    x = jnp.expand_dims(x,axis=0)
    dn = jax.lax.conv_dimension_numbers(x.shape, 
                                        (1,1,kernel_size[0],kernel_size[1]), 
                                        ('NHWC', 'IOHW', 'NHWC'))
    x = jax.lax.conv_general_dilated_patches(lhs = x,
                                               filter_shape= kernel_size,
                                               window_strides=strides,
                                               padding=padding,
                                               dimension_numbers=dn 
                                              )
    iters = x.shape[1:3]
    x = jnp.reshape(x, (-1, np.prod(kernel_size)))
    
    x = get_node(template)(x, *qweights)
    
    x = jnp.reshape(x, iters + (n_qubits,))
    return x

In [111]:
dev = qml.device("default.qubit", wires=n_qubits)
qnode = qml.QNode(get_node(template), dev)

inputs = np.random.uniform(size = (10,np.prod(kernel_size)))
weights = params[0]
drawer = qml.draw(qnode, expansion_strategy="device")
print(drawer(inputs,*weights))

0: ──H──Rot─╭●──Rot────╭Z─┤  <Z>
1: ──H──Rot─╰Z──Rot─╭●─│──┤  <Z>
2: ──H──Rot─╭●──Rot─╰Z─│──┤  <Z>
3: ──H──Rot─╰Z──Rot────╰●─┤  <Z>


## Auto-Batching Predictions

In [112]:
from jax.scipy.special import logsumexp

def relu(x):
    return jnp.maximum(0, x)

def forward(params, image):
  # per-example predictions
    activations = qconv(image, params[0])
#     activations += image
#     activations = relu(activations)
    activations = jnp.reshape(activations, (-1))
    for w, b in params[1:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)

In [113]:
def forwardx(params, image):
  # per-example predictions
    activations = qconv(image, params[0])
    activations = relu(activations)
    activations = qconv(activations, params[1])
    activations += image
    activations = relu(activations)
    
    activations = jnp.reshape(activations, (-1))
    for w, b in params[2:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)

In [114]:
# This works on single examples
random_flattened_image = random.normal(random.PRNGKey(1), input_dims)
random_flattened_image = jnp.floor(random_flattened_image*10)
preds = forward(params,  random_flattened_image)
print(preds)

[-0.11211586 -2.243807  ]


In [115]:
# Doesn't work with a batch
random_flattened_images = random.normal(random.PRNGKey(1), (2,)+ input_dims)
random_flattened_images = jnp.floor(random_flattened_images*10)
# try:
#     preds = predict(params, random_flattened_images)
# except TypeError:
#     print('Invalid shapes!')

In [116]:
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_forward = vmap(forward, in_axes=(None,0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_forward(params, random_flattened_images)
print(batched_preds)

[[  0.         -19.149054  ]
 [ -2.0530276   -0.13735485]]


## Utility and loss functions

In [117]:
from sklearn.metrics import roc_auc_score

def accuracy(y_true, y_pred):
    target_class = jnp.argmax(y_true, axis=1)
    predicted_class = jnp.argmax(y_pred, axis=1)
    return jnp.mean(predicted_class == target_class)
 

def loss_fn(params, images, targets):
    preds = batched_forward(params, images)
    loss_value = -jnp.mean(preds * targets)
    return loss_value, preds

@jit
def update(opt_state, params, x, y):
    (loss_value, preds), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, x, y)
    acc = accuracy(y,preds)
    
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    
    return params, opt_state, loss_value, acc 


def step(params,x,y):
    loss_value, preds = loss_fn(params, x, y)
    acc = accuracy(y, preds)
    return loss_value, acc

def evaluate(params, ds):
    losses = []
    accs = []
    with tqdm(tfds.as_numpy(ds), unit="batch") as tepoch:
        for x, y in tepoch:
            loss_value, acc = step(params, x, y)
            losses.append(loss_value)
            accs.append(acc)
            tepoch.set_postfix(loss=loss_value, acc=acc)
            
    return jnp.mean(np.array(losses)), jnp.mean(np.array(accs))

def predict(params, ds):
    preds = []
    y_true = []
    with tqdm(tfds.as_numpy(ds), unit="batch") as tepoch:
        for x, y in tepoch:
            preds += list(batched_forward(params, x))
            y_true += list(y)
    
    return np.array(preds), np.array(y_true)

## Training loop

In [118]:
lr = 1e-4

In [119]:
schedule_fn = optax.linear_schedule(transition_steps=150,
                                    init_value=0.2,
                                    end_value=1e-7,
                                    )
# Defining an optimizer in Jax 
# optimizer = optax.adam(learning_rate=schedule_fn)

print(lr)
optimizer = optax.adam(learning_rate=args.learning_rate)
# optimizer = optax.adam(learning_rate=lr)
opt_state = optimizer.init(params)
lr = (lr*0.1)

0.0001


In [None]:
import time

epochs = args.epochs
# epochs = 5
for epoch in range(epochs):
    start_time = time.time()

    with tqdm(tfds.as_numpy(data.train_ds), unit="batch") as tepoch:
        for x, y in tepoch:
            tepoch.set_description(f"Epoch {epoch}")
            params, opt_state, loss_value, acc = update(opt_state, params, x, y)
            tepoch.set_postfix(loss=loss_value, acc=acc)
        
    epoch_time = time.time() - start_time

    val_loss, val_acc = evaluate(params, data.val_ds)
    print('val_loss: {} - val_acc: {}-  time: {}'.format(val_loss, val_acc, epoch_time))
    
    if args.wandb:
        wandb.log({"accuracy": acc, 
                   "val_accuracy": val_acc, 
                   'loss':loss_value, 
                   'val_loss':val_loss})


Epoch 0: 100%|██████████████████████████████████| 743/743 [09:09<00:00,  1.35batch/s, acc=0.5416667, loss=0.34963474]
100%|█████████████████████████████████████████████████| 40/40 [00:12<00:00,  3.33batch/s, acc=0.375, loss=0.38454804]


val_loss: 0.3442888855934143 - val_acc: 0.5472656488418579-  time: 549.2657237052917


Epoch 1: 100%|███████████████████████████████████| 743/743 [08:55<00:00,  1.39batch/s, acc=0.2916667, loss=0.3705942]
100%|█████████████████████████████████████████████████| 40/40 [00:05<00:00,  7.71batch/s, acc=0.625, loss=0.33907646]


val_loss: 0.34018948674201965 - val_acc: 0.573046863079071-  time: 535.5697953701019


Epoch 2: 100%|████████████████████████████████████████| 743/743 [08:55<00:00,  1.39batch/s, acc=0.5, loss=0.36190435]
100%|█████████████████████████████████████████████████| 40/40 [00:05<00:00,  7.71batch/s, acc=0.875, loss=0.32057077]


val_loss: 0.3337971270084381 - val_acc: 0.6048828363418579-  time: 535.4961884021759


Epoch 3: 100%|████████████████████████████████████████| 743/743 [08:55<00:00,  1.39batch/s, acc=0.5, loss=0.37026352]
100%|██████████████████████████████████████████████████| 40/40 [00:05<00:00,  7.72batch/s, acc=0.75, loss=0.36012775]


val_loss: 0.3312808573246002 - val_acc: 0.6117187738418579-  time: 535.2245771884918


Epoch 4: 100%|████████████████████████████████████████| 743/743 [08:55<00:00,  1.39batch/s, acc=0.5, loss=0.36378896]
100%|█████████████████████████████████████████████████| 40/40 [00:05<00:00,  7.74batch/s, acc=0.875, loss=0.25919777]


val_loss: 0.32500097155570984 - val_acc: 0.624218761920929-  time: 535.2908368110657


Epoch 5: 100%|███████████████████████████████████████| 743/743 [08:55<00:00,  1.39batch/s, acc=0.625, loss=0.3116935]
100%|█████████████████████████████████████████████████████| 40/40 [00:05<00:00,  7.74batch/s, acc=0.75, loss=0.29148]


val_loss: 0.32319164276123047 - val_acc: 0.6244140863418579-  time: 535.302996635437


Epoch 6: 100%|██████████████████████████████████| 743/743 [08:55<00:00,  1.39batch/s, acc=0.5833334, loss=0.32310146]
100%|███████████████████████████████████████████████████| 40/40 [00:05<00:00,  7.74batch/s, acc=0.5, loss=0.38066936]


val_loss: 0.32493677735328674 - val_acc: 0.620312511920929-  time: 535.4283833503723


Epoch 7: 100%|███████████████████████████████████| 743/743 [08:55<00:00,  1.39batch/s, acc=0.5833334, loss=0.3202511]
100%|█████████████████████████████████████████████████| 40/40 [00:05<00:00,  7.71batch/s, acc=0.125, loss=0.46611574]


val_loss: 0.3265770375728607 - val_acc: 0.6162109375-  time: 535.4490578174591


Epoch 8: 100%|███████████████████████████████████████| 743/743 [08:55<00:00,  1.39batch/s, acc=0.75, loss=0.26556504]
100%|██████████████████████████████████████████████████| 40/40 [00:05<00:00,  7.74batch/s, acc=0.75, loss=0.32307842]


val_loss: 0.3219431936740875 - val_acc: 0.6363281607627869-  time: 535.3652205467224


Epoch 9: 100%|██████████████████████████████████████| 743/743 [08:55<00:00,  1.39batch/s, acc=0.625, loss=0.35050088]
100%|█████████████████████████████████████████████████| 40/40 [00:05<00:00,  7.73batch/s, acc=0.625, loss=0.32838377]


val_loss: 0.3223625719547272 - val_acc: 0.6357421875-  time: 535.4385662078857


Epoch 10: 100%|█████████████████████████████████| 743/743 [08:55<00:00,  1.39batch/s, acc=0.5833334, loss=0.32374728]
100%|██████████████████████████████████████████████████| 40/40 [00:05<00:00,  7.72batch/s, acc=0.75, loss=0.23875633]


val_loss: 0.3204209506511688 - val_acc: 0.6363281607627869-  time: 535.3487136363983


Epoch 11: 100%|█████████████████████████████████| 743/743 [08:55<00:00,  1.39batch/s, acc=0.5833334, loss=0.34833393]
100%|██████████████████████████████████████████████████| 40/40 [00:05<00:00,  7.73batch/s, acc=0.75, loss=0.22770637]


val_loss: 0.3192037045955658 - val_acc: 0.6429687738418579-  time: 535.4102993011475


Epoch 12:  43%|██████████████▌                   | 317/743 [03:49<05:07,  1.39batch/s, acc=0.671875, loss=0.30435002]

In [None]:
test_loss, test_acc = evaluate(params, data.test_ds)
test_loss, test_acc

In [None]:
from sklearn.metrics import roc_auc_score

out,y_test = predict(params, data.test_ds)
# _, y_test = tf_ds_to_numpy(data.test_ds)
test_auc = roc_auc_score(y_test, out)
test_auc

In [None]:
if args.wandb:
    wandb.run.summary['test_loss'] = test_loss
    wandb.run.summary['test_acc'] = test_acc
    wandb.run.summary['test_auc'] = test_auc
    y = y_test.argmax(axis=1)
    preds = out.argmax(axis=1)
    probs = out
    classes = data.mapping

    roc_curve = wandb.sklearn.plot_roc(y, probs, classes)
    confusion_matrix = wandb.sklearn.plot_confusion_matrix(y, preds, classes)

    wandb.log({"roc_curve": roc_curve})
    wandb.log({"confusion_matrix": confusion_matrix})

In [None]:
if args.wandb:
    wandb.finish()

In [32]:
for i in range(200):
    print(i, schedule_fn(i))

0 0.2
1 0.19866668
2 0.19733334
3 0.19600001
4 0.19466668
5 0.19333333
6 0.192
7 0.19066668
8 0.18933333
9 0.18800001
10 0.18666668
11 0.18533334
12 0.18400002
13 0.18266669
14 0.18133333
15 0.18
16 0.17866668
17 0.17733334
18 0.17600001
19 0.17466669
20 0.17333335
21 0.17200002
22 0.1706667
23 0.16933335
24 0.16800003
25 0.16666669
26 0.16533335
27 0.16400002
28 0.1626667
29 0.16133335
30 0.16000003
31 0.15866669
32 0.15733334
33 0.15600002
34 0.15466669
35 0.15333335
36 0.15200002
37 0.1506667
38 0.14933336
39 0.14800003
40 0.1466667
41 0.14533336
42 0.14400004
43 0.14266671
44 0.14133336
45 0.14000003
46 0.1386667
47 0.13733336
48 0.13600004
49 0.1346667
50 0.13333336
51 0.13200003
52 0.1306667
53 0.12933336
54 0.12800004
55 0.12666671
56 0.12533337
57 0.124000035
58 0.1226667
59 0.121333376
60 0.12000004
61 0.11866671
62 0.11733337
63 0.116000034
64 0.1146667
65 0.113333374
66 0.11200004
67 0.1106667
68 0.109333366
69 0.10800003
70 0.10666671
71 0.10533337
72 0.10400004
73 0.102666