In [3]:
# Reload imported modules automatically so edits to them are picked up live.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline in the notebook.
%matplotlib inline

from importlib.util import find_spec
# When `qml_hep_lhc` cannot be found on the import path, add the directory two
# levels up to sys.path so the import in the next cell can resolve.
if find_spec("qml_hep_lhc") is None:
    import sys
    sys.path.append('../..')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
from qml_hep_lhc.data import ElectronPhoton, MNIST, QuarkGluon
from qml_hep_lhc.data.utils import tf_ds_to_numpy
import argparse
import wandb

import pennylane as qml
import jax.numpy as jnp
import jax
import optax
from jax.nn.initializers import he_uniform
from jax import grad, jit, vmap
from jax import random
import tensorflow_datasets as tfds
from tqdm import tqdm
import numpy as np
import tensorflow as tf

# Added to silence some warnings.
# from jax.config import config
# config.update("jax_enable_x64", True)

import matplotlib.pyplot as plt
import time

2022-10-23 22:02:30.513500: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2022-10-23 22:02:30.513610: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [5]:
# Show the devices visible to JAX (the recorded run lists a single CPU device).
jax.devices()



[CpuDevice(id=0)]

In [158]:
# Experiment configuration collected in one namespace; it is handed to the data
# module and (when enabled) logged to Weights & Biases as the run config.
args = argparse.Namespace(
    # Data
    center_crop=0.7,
    resize=[8, 8],
    standardize=1,
    # power_transform=1,
    binary_data=[0, 1],
    percent_samples=0.01,
    # processed=1,
    dataset_type='1',
    labels_to_categorical=1,
    batch_size=128,
    validation_split=0.1,
    # NOTE(review): binary_data selects two labels while num_classes is 5;
    # num_classes is used below as the leading axis of every parameter tensor.
    # Confirm this is intended.
    num_classes=5,

    # Base Model
    wandb=False,
    epochs=50,
    learning_rate=0.001,

    # Quantum CNN Parameters
    n_layers=1,
    n_qubits=1,
    template='NQubitPQC',
    initializer='he_uniform',

    # Patch extraction for the quantum convolution
    kernel_size=(3, 3),
    strides=(1, 1),
    padding="SAME",

    # Output widths of the classical dense layers
    clayer_sizes=[8, 2],
)

In [159]:
# Start a Weights & Biases run only when logging is switched on in the config;
# the whole namespace is recorded as the run configuration.
if args.wandb:
    wandb.init(project='qml-hep-lhc', config=vars(args))

In [160]:
# Build the MNIST data module from the experiment config, run its preparation
# and setup steps in order, then print its summary tables.
data = MNIST(args)
data.prepare_data()
data.setup()
print(data)

Binarizing data...
Binarizing data...
Center cropping...
Center cropping...
Resizing data...
Resizing data...
Standardizing data...
Converting labels to categorical...
Converting labels to categorical...

Dataset :MNIST
╒════════╤════════════════╤═══════════════╤═══════════════╤═══════════╕
│ Data   │ Train size     │ Val size      │ Test size     │ Dims      │
╞════════╪════════════════╪═══════════════╪═══════════════╪═══════════╡
│ X      │ (108, 8, 8, 1) │ (12, 8, 8, 1) │ (20, 8, 8, 1) │ (8, 8, 1) │
├────────┼────────────────┼───────────────┼───────────────┼───────────┤
│ y      │ (108, 2)       │ (12, 2)       │ (20, 2)       │ (2,)      │
╘════════╧════════════════╧═══════════════╧═══════════════╧═══════════╛

╒══════════════╤═══════╤═══════╤════════╤═══════╤══════════════════════════╕
│ Type         │   Min │   Max │   Mean │   Std │ Samples for each class   │
╞══════════════╪═══════╪═══════╪════════╪═══════╪══════════════════════════╡
│ Train Images │ -1.63 │ 10.48 │  -0    │  0

## Hyperparameters

In [161]:
# Per-sample input shape reported by the data module ((8, 8, 1) in the recorded run).
input_dims = data.config()['input_dims']
input_dims

(8, 8, 1)

In [162]:
def get_out_shape(in_shape, k, s, padding):
    """Return the shape of the patch tensor extracted from a single sample.

    Args:
        in_shape: Shape of one sample in HWC order, e.g. ``(8, 8, 1)``.
            Any sequence is accepted.
        k: Patch (kernel) size as ``(height, width)``.
        s: Window strides as ``(stride_h, stride_w)``.
        padding: Padding specification accepted by
            ``jax.lax.conv_general_dilated_patches`` (e.g. ``"SAME"``).

    Returns:
        The shape ``(1, out_h, out_w, channels * k[0] * k[1])`` of the patch
        tensor for a batch containing exactly one sample.
    """
    # Prepend a batch axis of size 1 so the shape is in NHWC layout.
    batched_shape = (1,) + tuple(in_shape)
    # Only the resulting shape matters, so a constant float32 array is enough;
    # this avoids drawing from NumPy's global random state as a side effect.
    dummy = jnp.zeros(batched_shape, dtype=jnp.float32)
    dn = jax.lax.conv_dimension_numbers(batched_shape,
                                        (1, 1, k[0], k[1]),
                                        ('NHWC', 'IOHW', 'NHWC'))
    out = jax.lax.conv_general_dilated_patches(lhs=dummy,
                                               filter_shape=tuple(k),
                                               window_strides=tuple(s),
                                               padding=padding,
                                               dimension_numbers=dn)
    return out.shape

In [163]:
# Shared He-uniform initializer used by the parameter helpers defined in this cell.
initializer = he_uniform()

# Get qlayer sizes
def get_qlayer_sizes(template, n_c, n_l, n_q, k_size):
    """Return the weight ('w') and bias ('b') shapes for a circuit template.

    Every shape starts with ``(n_c, n_l, n_q)``; the trailing axes depend on
    the template. ``k_size`` is the patch size, and its product is the number
    of values in one patch.

    Returns:
        A dict with keys ``'w'`` and ``'b'``, or ``None`` when ``template`` is
        not one of the four known names.

    Raises:
        AssertionError: for 'NQubitPQC' and 'Qernel' when the patch length is
            not a multiple of 3, and for 'Qernel' when ``n_q`` is not 3.
    """
    lead = (n_c, n_l, n_q)

    if template == 'Qernel':
        assert n_q == 3
        assert np.prod(k_size) % 3 == 0
        return {'w': lead + (3,), 'b': lead + (1,)}

    if template not in ('NQubitPQCSparse', 'LightPQC', 'NQubitPQC'):
        # Unknown templates fall through without a result.
        return None

    patch_len = np.prod(k_size)

    if template == 'NQubitPQC':
        # Patch values are consumed three at a time (one RZ-RY-RZ triple).
        assert patch_len % 3 == 0
        return {'w': lead + (patch_len,), 'b': lead + (patch_len,)}

    trailing_axes = {
        'NQubitPQCSparse': ((3, patch_len), (3, 1)),
        'LightPQC': ((patch_len,), (1,)),
    }
    w_tail, b_tail = trailing_axes[template]
    return {'w': lead + w_tail, 'b': lead + b_tail}

def random_qlayer_params(size, key, scale=1e-1):
    """Draw one quantum-layer parameter tensor from the shared initializer.

    Args:
        size: Shape of the tensor to create.
        key: JAX PRNG key passed to the initializer.
        scale: Unused. Kept so callers that pass it keep working; the scaled
            normal draw it belonged to was unreachable and has been removed.

    Returns:
        The array of shape ``size`` produced by ``initializer``.
    """
    return initializer(key, size)

def init_qnetwork_params(sizes, key):
    """Initialise the parameter tensors of one quantum layer.

    Args:
        sizes: Mapping from parameter name to shape; values are used in
            iteration order.
        key: JAX PRNG key, split into one sub-key per entry of ``sizes``.

    Returns:
        A one-element list holding the list of tensors for this layer, so it
        can be concatenated onto the network's list of layers.
    """
    subkeys = random.split(key, len(sizes))
    layer = []
    for shape, subkey in zip(sizes.values(), subkeys):
        layer.append(random_qlayer_params(shape, subkey))
    return [layer]
 

def random_clayer_params(m, n, n_c, key, scale=1e-1):
    """Randomly initialise the weights and biases of one dense layer.

    Args:
        m: Input width of the layer.
        n: Output width of the layer.
        n_c: Size of the leading axis added to both tensors.
        key: JAX PRNG key, split into one key for weights and one for biases.
        scale: NOTE(review): accepted but not applied to either tensor;
            confirm whether the bias draw was meant to be scaled.

    Returns:
        ``(weights, biases)`` with shapes ``(n_c, n, m)`` and ``(n_c, n)``.
        Weights come from the shared initializer, biases from a standard
        normal draw.
    """
    w_key, b_key = random.split(key)
    weights = initializer(w_key, (n_c, n, m))
    biases = random.normal(b_key, (n_c, n))
    return weights, biases

def init_network_params(sizes, n_c, key):
    """Initialise every dense layer of a fully-connected network.

    Args:
        sizes: Layer widths; consecutive pairs give (input, output) widths.
        n_c: Leading-axis size forwarded to ``random_clayer_params``.
        key: JAX PRNG key.

    Returns:
        A list with one ``(weights, biases)`` pair per consecutive width pair.
    """
    # One key is drawn per entry of `sizes`; zip stops after len(sizes) - 1
    # layers, so the last key is left unused.
    layer_keys = random.split(key, len(sizes))
    layers = []
    for fan_in, fan_out, layer_key in zip(sizes[:-1], sizes[1:], layer_keys):
        layers.append(random_clayer_params(fan_in, fan_out, n_c, layer_key))
    return layers

# Pull the hyperparameters out of the config into module-level names that the
# circuit and convolution cells read.
kernel_size, strides, padding = args.kernel_size, args.strides, args.padding
num_classes = args.num_classes
template, n_layers, n_qubits = args.template, args.n_layers, args.n_qubits

# Shape of the patch tensor for one sample; every axis but the last, times the
# number of qubits, gives the width fed to the first dense layer.
conv_out_shape = get_out_shape(input_dims, kernel_size, strides, padding)
num_pixels = np.prod(conv_out_shape[:-1]) * n_qubits
qlayer_sizes = get_qlayer_sizes(template, num_classes, n_layers, n_qubits, kernel_size)
clayer_sizes = [num_pixels] + args.clayer_sizes

# Parameter list: the quantum layer first, then the dense layers.
params = [
    *init_qnetwork_params(qlayer_sizes, random.PRNGKey(0)),
    # *init_qnetwork_params(qlayer_sizes, random.PRNGKey(1)),
    *init_network_params(clayer_sizes, num_classes, random.PRNGKey(2)),
]

  rhs = jnp.eye(spatial_size, dtype=lhs.dtype).reshape(filter_shape * 2)


In [164]:
for i in params:
    for j in i:
        print(j.shape, end = ' ')
    print()

(5, 1, 1, 9) (5, 1, 1, 9) 
(5, 8, 64) (5, 8) 
(5, 2, 8) (5, 2) 


## QLayers

In [166]:
# Device and wire list for this template. The same two names are rebound
# identically in each of the template cells.
dev = qml.device('default.qubit.jax', wires=n_qubits)
qubits =list(range(n_qubits))

@jax.jit
@qml.qnode(dev, interface='jax')
def NQubitPQCSparse(inputs, w, b):
    """Circuit template using one Rot gate per qubit per layer.

    The rotation angles are an affine map of the inputs,
    ``z = w @ inputs.T + b``, indexed as ``z[layer, qubit, angle]``.
    Reads the module-level ``qubits`` and ``n_layers``.

    Assumes ``inputs`` is (batch, features) and ``w``/``b`` carry
    (layer, qubit, 3, ...) axes -- TODO confirm against the caller.

    Returns:
        A list with the PauliZ expectation value of every qubit.
    """
    z = jnp.dot(w, jnp.transpose(inputs))+ b

    # Start every wire in an equal superposition.
    for q in qubits:
        qml.Hadamard(wires=q)
    
    for l in range(n_layers):
        for q in qubits:
            qml.Rot(z[l,q,0], z[l,q,1], z[l,q,2], wires= q)
        if (l & 1):
            # Odd layers: pair wires 1, 3, ... with the next wire; wire 0 is
            # appended as the final partner, closing the chain.
            for q0, q1 in zip(qubits[1::2], qubits[2::2] + [qubits[0]]):
                qml.CZ((q0,q1))
        else:
            # Even layers: pair wires 0, 2, ... with the next wire.
            for q0, q1 in zip(qubits[0::2], qubits[1::2]):
                qml.CZ((q0,q1))
   
    return [qml.expval(qml.PauliZ(q)) for q in qubits]

In [167]:
# Device and wire list for this template. The same two names are rebound
# identically in each of the template cells.
dev = qml.device('default.qubit.jax', wires=n_qubits)
qubits =list(range(n_qubits))

@jax.jit
@qml.qnode(dev, interface='jax')
def NQubitPQC(inputs, w, b):
    """Circuit template that encodes the input three features at a time.

    For every layer and qubit, each consecutive group of three features is
    scaled element-wise by the matching slice of ``w``, shifted by the matching
    slice of ``b``, and applied as an RZ-RY-RZ sequence. Features beyond the
    last full group of three are not used (``steps`` is an integer division).
    Reads the module-level ``qubits`` and ``n_layers``.

    ``inputs`` is indexed as ``inputs[:, a:b]``; presumably (batch, features),
    with ``w``/``b`` indexed as [layer, qubit, feature] -- TODO confirm.

    Returns:
        A list with the PauliZ expectation value of every qubit.
    """
    steps = inputs.shape[-1]//3
    # Start every wire in an equal superposition.
    for q in qubits:
        qml.Hadamard(wires=q)
    
    for l in range(n_layers):
        for q in qubits:
            for i in range(steps):
                # Transposed so that z[0], z[1], z[2] are the three angles.
                z = jnp.transpose(jnp.multiply(inputs[:,3*i:3*i+3],w[l,q,3*i:3*i+3]) + b[l,q,3*i:3*i+3])
                qml.RZ(z[0], wires=q)
                qml.RY(z[1], wires=q)
                qml.RZ(z[2], wires=q)
                
        if (l & 1):
            # Odd layers: pair wires 1, 3, ... with the next wire; wire 0 is
            # appended as the final partner, closing the chain.
            for q0, q1 in zip(qubits[1::2], qubits[2::2] + [qubits[0]]):
                qml.CZ((q0,q1))
        else:
            # Even layers: pair wires 0, 2, ... with the next wire.
            for q0, q1 in zip(qubits[0::2], qubits[1::2]):
                qml.CZ((q0,q1))

    return [qml.expval(qml.PauliZ(q)) for q in qubits]

In [168]:
# Device and wire list for this template. The same two names are rebound
# identically in each of the template cells.
dev = qml.device('default.qubit.jax', wires=n_qubits)
qubits =list(range(n_qubits))

@jax.jit
@qml.qnode(dev, interface='jax')
def LightPQC(inputs, w, b):
    """Circuit template with scaled-input rotations and one RX bias per qubit.

    For every layer and qubit, each consecutive group of three features is
    scaled element-wise by the matching slice of ``w`` and applied as an
    RZ-RY-RZ sequence; afterwards a single RX with angle ``b[layer, qubit, 0]``
    is applied. Features beyond the last full group of three are not used.
    Reads the module-level ``qubits`` and ``n_layers``.

    ``inputs`` is indexed as ``inputs[:, a:b]``; presumably (batch, features)
    -- TODO confirm.

    Returns:
        A list with the PauliZ expectation value of every qubit.
    """
    steps = inputs.shape[-1]//3
    # Start every wire in an equal superposition.
    for q in qubits:
        qml.Hadamard(wires=q)
    
    for l in range(n_layers):
        for q in qubits:
            for i in range(steps):
                # Transposed so that z[0], z[1], z[2] are the three angles.
                z = jnp.transpose(jnp.multiply(inputs[:,3*i:3*i+3],w[l,q,3*i:3*i+3]))
                qml.RZ(z[0], wires=q)
                qml.RY(z[1], wires=q)
                qml.RZ(z[2], wires=q)
            # The bias enters as one extra rotation rather than an angle offset.
            qml.RX(b[l,q,0], wires = q)
        if (l & 1):
            # Odd layers: pair wires 1, 3, ... with the next wire; wire 0 is
            # appended as the final partner, closing the chain.
            for q0, q1 in zip(qubits[1::2], qubits[2::2] + [qubits[0]]):
                qml.CZ((q0,q1))
        else:
            # Even layers: pair wires 0, 2, ... with the next wire.
            for q0, q1 in zip(qubits[0::2], qubits[1::2]):
                qml.CZ((q0,q1))

    return [qml.expval(qml.PauliZ(q)) for q in qubits]

In [169]:
# Device and wire list for this template. The same two names are rebound
# identically in each of the template cells.
dev = qml.device('default.qubit.jax', wires=n_qubits)
qubits =list(range(n_qubits))

@jax.jit
@qml.qnode(dev, interface='jax')
def Qernel(inputs, w, b):
    """Circuit template in which qubit ``q`` encodes features ``3q .. 3q+2``.

    After transposing the inputs, each qubit takes its own three rows, scales
    them by ``w[layer, qubit]`` (tiled across the trailing axis) and applies
    them as one Rot gate, followed by an RX with angle ``b[layer, qubit, 0]``.
    No two-qubit gates are applied. Reads the module-level ``qubits`` and
    ``n_layers``.

    Presumably ``inputs`` arrives as (batch, features) -- TODO confirm.

    Returns:
        A list with the PauliZ expectation value of every qubit.
    """
    inputs = jnp.transpose(inputs)
    # Size of the trailing axis after the transpose.
    batch_dim = inputs.shape[-1]

    # Start every wire in an equal superposition.
    for q in qubits:
        qml.Hadamard(wires=q)
    
    for l in range(n_layers):
        for q in qubits:
            # Tile the per-qubit weights so they line up with the input rows.
            z = jnp.multiply(inputs[3*q:3*q+3, :], jnp.transpose(jnp.tile(w[l,q], (batch_dim,1))))
            qml.Rot(z[0], z[1], z[2], wires= q)
            qml.RX(b[l,q,0], wires = q)
    return [qml.expval(qml.PauliZ(q)) for q in qubits]

In [171]:
def get_nodes(template):
    """Return one circuit callable per class for the requested template.

    The template functions in this notebook are already decorated with
    ``@qml.qnode`` and ``@jax.jit``, so they are returned as they are instead
    of being wrapped in another ``qml.QNode`` (which also needs a device).

    Args:
        template: One of 'NQubitPQC', 'LightPQC', 'NQubitPQCSparse', 'Qernel'.

    Returns:
        A list of length ``num_classes`` (module-level) whose entries all refer
        to the selected circuit callable.

    Raises:
        ValueError: if ``template`` is not a known template name.
    """
    # The if-chain only touches the selected template's name, so templates
    # whose cell has not been run do not matter.
    if template == 'NQubitPQC':
        node = NQubitPQC
    elif template == 'LightPQC':
        node = LightPQC
    elif template == 'NQubitPQCSparse':
        node = NQubitPQCSparse
    elif template == 'Qernel':
        node = Qernel
    else:
        raise ValueError(f"Unknown circuit template: {template!r}")

    # num_classes is an int, so iterate over range(num_classes).
    return [node for _ in range(num_classes)]

In [135]:
def qconv(x, qweights):
    """Extract convolution patches from one image for the quantum layer.

    Reads the module-level ``kernel_size``, ``strides``, ``padding``,
    ``template`` and ``n_qubits``.

    NOTE(review): this function looks unfinished. ``qweights`` and ``nodes``
    are never used, so no circuit is applied to the patches; the commented-out
    call suggests one was intended. The final reshape targets
    ``iters + (n_qubits,)`` while ``x`` is still the flattened patch matrix
    with ``prod(kernel_size)`` columns -- confirm the intended wiring.
    """
    # Add a leading batch axis so the image matches the NHWC layout below.
    x = jnp.expand_dims(x,axis=0)
    dn = jax.lax.conv_dimension_numbers(x.shape, 
                                        (1,1,kernel_size[0],kernel_size[1]), 
                                        ('NHWC', 'IOHW', 'NHWC'))
    x = jax.lax.conv_general_dilated_patches(lhs = x,
                                               filter_shape= kernel_size,
                                               window_strides=strides,
                                               padding=padding,
                                               dimension_numbers=dn 
                                              )
    # Spatial extent of the patch grid, kept for the final reshape.
    iters = x.shape[1:3]
    # One row per patch, one column per kernel position.
    x = jnp.reshape(x, (-1, np.prod(kernel_size)))
    
#     x = get_nodes(template)(x, *qweights)
    nodes = get_nodes(template)
    
    
    x = jnp.reshape(x, iters + (n_qubits,))
    return x

In [136]:
# Draw the selected circuit template on a plain (non-JAX) simulator device.
# Note that this rebinds the module-level `dev`.
dev = qml.device("default.qubit", wires=n_qubits)
# NOTE(review): `get_node` is not defined in any visible cell (only `get_nodes`
# is) -- confirm which helper is intended here.
qnode = qml.QNode(get_node(template), dev)

# Ten random patches, one column per kernel position (unseeded NumPy draw).
inputs = np.random.uniform(size = (10,np.prod(kernel_size)))
# Quantum-layer tensors: the first entry of the parameter list.
weights = params[0]
drawer = qml.draw(qnode, expansion_strategy="device")
print(drawer(inputs,*weights))

0: ──H──RZ(M0)──RY(M1)──RZ(M2)──RZ(M3)──RY(M4)──RZ(M5)──RZ(M6)──RY(M7)──RZ(M8)─┤  <Z>


## Auto-Batching Predictions

In [137]:
from jax.scipy.special import logsumexp

def relu(x):
    """Rectified linear unit: the element-wise maximum of 0 and ``x``."""
    rectified = jnp.maximum(0, x)
    return rectified

def forward(params, image):
    """Compute per-class log-probabilities for a single image.

    Args:
        params: Parameter list; ``params[0]`` is handed to ``qconv``, the
            middle entries are hidden dense layers and ``params[-1]`` is the
            output layer, each dense layer being a ``(weights, biases)`` pair.
        image: One input sample (no batch axis).

    Returns:
        The logits minus their log-sum-exp, i.e. a log-softmax.
    """
    # Quantum convolution with a residual connection back to the input.
    hidden = relu(qconv(image, params[0]) + image)
    hidden = jnp.reshape(hidden, (-1))

    for weight, bias in params[1:-1]:
        hidden = relu(jnp.dot(weight, hidden) + bias)

    out_weight, out_bias = params[-1]
    logits = jnp.dot(out_weight, hidden) + out_bias
    return logits - logsumexp(logits)

In [138]:
def forwardx(params, image):
    """Variant of ``forward`` that stacks two quantum convolutions.

    Args:
        params: Parameter list; ``params[0]`` and ``params[1]`` are handed to
            the two ``qconv`` calls, ``params[2:-1]`` are hidden dense layers
            and ``params[-1]`` is the output layer.
            NOTE(review): the cell that builds ``params`` has its second
            quantum layer commented out -- confirm ``params[1]`` is a quantum
            layer before using this function.
        image: One input sample (no batch axis).

    Returns:
        The logits minus their log-sum-exp, i.e. a log-softmax.
    """
    first = relu(qconv(image, params[0]))
    # The residual connection is added after the second convolution.
    hidden = relu(qconv(first, params[1]) + image)
    hidden = jnp.reshape(hidden, (-1))

    for weight, bias in params[2:-1]:
        hidden = relu(jnp.dot(weight, hidden) + bias)

    out_weight, out_bias = params[-1]
    logits = jnp.dot(out_weight, hidden) + out_bias
    return logits - logsumexp(logits)

In [139]:
# `forward` works on a single example: draw one random image with the input
# shape (despite the variable name it is not flattened), scale it by 10 and
# floor it, then run it through the network.
random_flattened_image = random.normal(random.PRNGKey(1), input_dims)
random_flattened_image = jnp.floor(random_flattened_image*10)
preds = forward(params,  random_flattened_image)
print(preds)

[   0.      -161.09143]


In [140]:
# Build a batch of two random images (leading axis of size 2). The per-example
# function is not called on it here; the attempt below is left commented out
# and still refers to `predict` rather than `forward`.
random_flattened_images = random.normal(random.PRNGKey(1), (2,)+ input_dims)
random_flattened_images = jnp.floor(random_flattened_images*10)
# try:
#     preds = predict(params, random_flattened_images)
# except TypeError:
#     print('Invalid shapes!')

In [141]:
# Upgrade the per-example function to handle batches using `vmap`.

# Batched version of `forward`: the parameters are shared across the batch
# (in_axes None) and the images are mapped over their leading axis (in_axes 0).
batched_forward = vmap(forward, in_axes=(None,0))

# `batched_forward` has the same call signature as `forward`
batched_preds = batched_forward(params, random_flattened_images)
print(batched_preds)

[[   0.      -165.50987]
 [   0.       -44.39451]]


## Utility and loss functions

In [142]:
from sklearn.metrics import roc_auc_score

def accuracy(y_true, y_pred):
    """Count the samples whose predicted class matches the target class.

    Both arguments are reduced with ``argmax`` along axis 1. Note that the
    result is the number of correct predictions, not a fraction.
    """
    hits = jnp.argmax(y_pred, axis=1) == jnp.argmax(y_true, axis=1)
    return jnp.sum(hits)
 

def loss_fn(params, images, targets):
    """Return the loss and the network outputs for a batch.

    The loss is the negated mean, over every element, of the network outputs
    multiplied by the targets. The outputs are returned as auxiliary data so
    callers can reuse them.
    """
    outputs = batched_forward(params, images)
    return -jnp.mean(outputs * targets), outputs

@jit
def update(opt_state, params, x, y):
    """Run one optimizer step on a batch and return the new state.

    Uses the module-level ``optimizer``.

    Returns:
        ``(params, opt_state)`` after applying the gradient update.
    """
    # Only the gradients are needed; the loss value and outputs are discarded.
    (_loss, _outputs), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, x, y)

    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state

@jit
def step(params,x,y):
    """Evaluate one batch: return its loss and its number of correct predictions."""
    loss_value, outputs = loss_fn(params, x, y)
    return loss_value, accuracy(y, outputs)

def evaluate(params, ds):
    """Compute the mean loss and accuracy of the model over a dataset.

    Each batch is weighted by its actual size, so a smaller final batch does
    not bias the results (dividing by the configured batch size would
    under-report accuracy in that case).

    Args:
        params: Network parameters passed to ``step``.
        ds: Batched dataset yielding ``(x, y)`` pairs once converted with
            ``tfds.as_numpy``.

    Returns:
        ``(loss, accuracy)``: the sample-weighted mean loss and the fraction of
        correctly classified samples. Both are NaN when the dataset is empty.
    """
    losses = []
    correct_counts = []
    batch_sizes = []
    with tqdm(tfds.as_numpy(ds), unit="batch") as tepoch:
        tepoch.set_description("Validation")
        for x, y in tepoch:
            loss_value, n_correct = step(params, x, y)
            losses.append(loss_value)
            correct_counts.append(n_correct)
            batch_sizes.append(len(x))

    n_samples = sum(batch_sizes)
    if n_samples == 0:
        # No samples: report NaN, as taking a mean over nothing would.
        return jnp.asarray(jnp.nan), jnp.asarray(jnp.nan)

    # `step` returns a per-batch mean loss and a per-batch count of hits.
    weighted_loss = jnp.sum(np.array(losses) * np.array(batch_sizes)) / n_samples
    acc = jnp.sum(np.array(correct_counts)) / n_samples
    return weighted_loss, acc

def predict(params, ds):
    """Run the network over a whole dataset.

    Returns:
        ``(outputs, labels)`` as NumPy arrays with one row per sample. The
        outputs are whatever ``batched_forward`` returns per sample.
    """
    outputs, labels = [], []
    with tqdm(tfds.as_numpy(ds), unit="batch") as tepoch:
        for x, y in tepoch:
            # Extend sample by sample so batches of different sizes stack cleanly.
            outputs.extend(batched_forward(params, x))
            labels.extend(y)

    return np.array(outputs), np.array(labels)

## Training loop

In [143]:
# Constant learning rate used by the optimizer and reported in the training logs.
lr = 1e-3

In [144]:
# Linear learning-rate schedule from 0.2 down to 1e-7 over 150 steps.
# It is not passed to the optimizer below, which uses the constant `lr`;
# the schedule-based optimizer is left commented out.
schedule_fn = optax.linear_schedule(transition_steps=150,
                                    init_value=0.2,
                                    end_value=1e-7,
                                    )
# Defining an optimizer in Jax 
# optimizer = optax.adam(learning_rate=schedule_fn)

print(lr)
# optimizer = optax.adam(learning_rate=args.learning_rate)
optimizer = optax.adam(learning_rate=lr)
# Optimizer state is initialised from the current parameters.
opt_state = optimizer.init(params)
# lr = (lr*np.sqrt(0.1))

0.001


In [145]:
import time

# Number of passes over the training set. The config value is not used here;
# the commented line shows how to switch back to it.
# epochs = args.epochs
epochs = 30

# Wall-clock duration of the training part of every epoch (evaluation excluded).
epoch_times = []
for epoch in range(epochs):
    # perf_counter is monotonic, so the measured duration cannot be distorted
    # by system clock adjustments.
    start_time = time.perf_counter()

    with tqdm(tfds.as_numpy(data.train_ds), unit="batch") as tepoch:
        # The label is the same for the whole epoch, so set it once.
        tepoch.set_description(f"Epoch {epoch}")
        for x, y in tepoch:
            params, opt_state = update(opt_state, params, x, y)

    epoch_time = time.perf_counter() - start_time
    epoch_times.append(epoch_time)

    # Metrics are measured with the parameters reached at the end of the epoch.
    loss, acc = evaluate(params, data.train_ds)
    val_loss, val_acc = evaluate(params, data.val_ds)

    print('loss: {} - acc: {}'.format(loss, acc))
    print('val_loss: {} - val_acc: {}'.format(val_loss, val_acc))
    print('time: {}'.format(epoch_time))

    if args.wandb:
        wandb.log({"accuracy": acc,
                   "val_accuracy": val_acc,
                   'loss': loss,
                   'val_loss': val_loss,
                   'lr': lr})


Epoch 0: 100%|████████████████████████████████████████████████████████████████████| 90/90 [00:27<00:00,  3.28batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████████| 90/90 [00:18<00:00,  4.74batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████████| 10/10 [00:09<00:00,  1.10batch/s]


loss: 0.6220170259475708 - acc: 0.640625
val_loss: 0.6115326285362244 - val_acc: 0.632031261920929
time: 27.42134952545166


Epoch 1: 100%|███████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 176.25batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 200.86batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 184.75batch/s]


loss: 0.2198825478553772 - acc: 0.862500011920929
val_loss: 0.19998501241207123 - val_acc: 0.8617187738418579
time: 0.5179848670959473


Epoch 2: 100%|███████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 179.03batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 226.58batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 194.67batch/s]


loss: 0.12094088643789291 - acc: 0.9265625476837158
val_loss: 0.1121101975440979 - val_acc: 0.9281250238418579
time: 0.508918046951294


Epoch 3: 100%|███████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 177.44batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 217.67batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 177.89batch/s]


loss: 0.07760623097419739 - acc: 0.9553819894790649
val_loss: 0.0734054446220398 - val_acc: 0.952343761920929
time: 0.5113320350646973


Epoch 4: 100%|███████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 182.64batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 204.56batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 185.18batch/s]


loss: 0.05574367940425873 - acc: 0.9671875238418579
val_loss: 0.0542498342692852 - val_acc: 0.965624988079071
time: 0.49904465675354004


Epoch 5: 100%|███████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 177.78batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 214.72batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 195.73batch/s]


loss: 0.04259795695543289 - acc: 0.9731771349906921
val_loss: 0.04529106989502907 - val_acc: 0.9710937738418579
time: 0.510899543762207


Epoch 6: 100%|███████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 178.51batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 217.72batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 206.05batch/s]


loss: 0.033732883632183075 - acc: 0.9763889312744141
val_loss: 0.03899955376982689 - val_acc: 0.9781250357627869
time: 0.508903980255127


Epoch 7: 100%|███████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 182.89batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 202.01batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 205.84batch/s]


loss: 0.02736518532037735 - acc: 0.979600727558136
val_loss: 0.03288235887885094 - val_acc: 0.9789062738418579
time: 0.4987175464630127


Epoch 8: 100%|███████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 174.90batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 182.54batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 179.57batch/s]


loss: 0.022491762414574623 - acc: 0.9812500476837158
val_loss: 0.03050808422267437 - val_acc: 0.9820312857627869
time: 0.519294261932373


Epoch 9: 100%|███████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 162.21batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 202.59batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 152.62batch/s]


loss: 0.01890205591917038 - acc: 0.9826388955116272
val_loss: 0.029078776016831398 - val_acc: 0.9828125238418579
time: 0.5617785453796387


Epoch 10: 100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 141.96batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 165.93batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 119.10batch/s]


loss: 0.015894506126642227 - acc: 0.984288215637207
val_loss: 0.027414266020059586 - val_acc: 0.984375
time: 0.6411342620849609


Epoch 11: 100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 166.31batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 173.70batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 182.34batch/s]


loss: 0.013730563223361969 - acc: 0.9847222566604614
val_loss: 0.025252550840377808 - val_acc: 0.985156238079071
time: 0.5507025718688965


Epoch 12: 100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 164.32batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 205.73batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 185.15batch/s]


loss: 0.011901434510946274 - acc: 0.9849826693534851
val_loss: 0.025575727224349976 - val_acc: 0.985156238079071
time: 0.5555939674377441


Epoch 13: 100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 140.51batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 179.49batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 176.47batch/s]


loss: 0.010375948622822762 - acc: 0.9855034947395325
val_loss: 0.023850280791521072 - val_acc: 0.9859375357627869
time: 0.6469278335571289


Epoch 14: 100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 142.03batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 172.46batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 168.91batch/s]


loss: 0.009106083773076534 - acc: 0.9858506917953491
val_loss: 0.02335468865931034 - val_acc: 0.9859375357627869
time: 0.6404075622558594


Epoch 15: 100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 152.77batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 208.73batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 183.83batch/s]


loss: 0.008092625066637993 - acc: 0.9860243201255798
val_loss: 0.022923244163393974 - val_acc: 0.9859375357627869
time: 0.5946457386016846


Epoch 16: 100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 168.07batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 211.02batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 196.64batch/s]


loss: 0.007242409512400627 - acc: 0.9865451455116272
val_loss: 0.023090338334441185 - val_acc: 0.9859375357627869
time: 0.5434188842773438


Epoch 17: 100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 173.31batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 142.81batch/s]
Validation: 100%|█████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 69.45batch/s]


loss: 0.006534216459840536 - acc: 0.9867187738418579
val_loss: 0.022192666307091713 - val_acc: 0.9859375357627869
time: 0.527224063873291


Epoch 18: 100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 171.73batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 198.99batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 198.70batch/s]


loss: 0.0058957431465387344 - acc: 0.9869791865348816
val_loss: 0.022419126704335213 - val_acc: 0.9859375357627869
time: 0.5290374755859375


Epoch 19: 100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 175.11batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 185.38batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 184.56batch/s]


loss: 0.0053279404528439045 - acc: 0.9872395992279053
val_loss: 0.02164122276008129 - val_acc: 0.985156238079071
time: 0.5192840099334717


Epoch 20: 100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 168.44batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 205.06batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 195.29batch/s]


loss: 0.004885667935013771 - acc: 0.987500011920929
val_loss: 0.022207750007510185 - val_acc: 0.985156238079071
time: 0.5388977527618408


Epoch 21: 100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 142.18batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 186.86batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 166.76batch/s]


loss: 0.00452882144600153 - acc: 0.987413227558136
val_loss: 0.022427840158343315 - val_acc: 0.985156238079071
time: 0.6400146484375


Epoch 22: 100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 166.61batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 199.18batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 166.42batch/s]


loss: 0.004210766404867172 - acc: 0.9876736402511597
val_loss: 0.021416908130049706 - val_acc: 0.985156238079071
time: 0.5496385097503662


Epoch 23: 100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 172.73batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 204.16batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 170.22batch/s]


loss: 0.003935822285711765 - acc: 0.9877604246139526
val_loss: 0.021198775619268417 - val_acc: 0.985156238079071
time: 0.5265817642211914


Epoch 24: 100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 167.39batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 187.88batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 150.24batch/s]


loss: 0.003687235526740551 - acc: 0.9878472685813904
val_loss: 0.021047506481409073 - val_acc: 0.985156238079071
time: 0.5432822704315186


Epoch 25: 100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 168.92batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 196.36batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 156.34batch/s]


loss: 0.003461812622845173 - acc: 0.9879340529441833
val_loss: 0.020473478361964226 - val_acc: 0.985156238079071
time: 0.5373961925506592


Epoch 26: 100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 164.22batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 197.61batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 177.18batch/s]


loss: 0.0032600052654743195 - acc: 0.9881076812744141
val_loss: 0.020372727885842323 - val_acc: 0.985156238079071
time: 0.5548381805419922


Epoch 27: 100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 165.27batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 200.39batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 142.17batch/s]


loss: 0.0030817899387329817 - acc: 0.988194465637207
val_loss: 0.0206682737916708 - val_acc: 0.984375
time: 0.5504982471466064


Epoch 28: 100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 171.29batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 195.61batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 195.03batch/s]


loss: 0.002918083220720291 - acc: 0.98828125
val_loss: 0.021269798278808594 - val_acc: 0.984375
time: 0.5307211875915527


Epoch 29: 100%|██████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 165.29batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 196.68batch/s]
Validation: 100%|████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 171.70batch/s]

loss: 0.0027653563302010298 - acc: 0.98828125
val_loss: 0.020011840388178825 - val_acc: 0.984375
time: 0.5533139705657959





In [146]:
from sklearn.metrics import roc_auc_score

# ROC AUC on the training split. `out` holds the raw outputs of `predict`
# (one row per sample), scored against the label rows.
out,y_train = predict(params, data.train_ds)
# _, y_test = tf_ds_to_numpy(data.test_ds)
train_auc = roc_auc_score(y_train, out)
train_auc

100%|█████████████████████████████████████████████████████████████████████████████| 90/90 [00:23<00:00,  3.78batch/s]


0.999955275483235

In [147]:
from sklearn.metrics import roc_auc_score

# ROC AUC on the test split. This rebinds `out`, which the wandb summary cell
# reads afterwards together with `y_test`.
out,y_test = predict(params, data.test_ds)
# _, y_test = tf_ds_to_numpy(data.test_ds)
test_auc = roc_auc_score(y_test, out)
test_auc

100%|█████████████████████████████████████████████████████████████████████████████| 17/17 [00:10<00:00,  1.61batch/s]


0.9993623572777128

In [149]:
# Record the final metrics and diagnostic plots when W&B logging is enabled.
if args.wandb:
    # No visible earlier cell assigns test_loss / test_acc, so compute them
    # here instead of relying on leftover kernel state.
    test_loss, test_acc = evaluate(params, data.test_ds)

    wandb.run.summary['test_loss'] = float(test_loss)
    wandb.run.summary['test_acc'] = float(test_acc)
    wandb.run.summary['test_auc'] = test_auc
    wandb.run.summary['train_auc'] = train_auc
    wandb.run.summary['avg_epoch_time'] = np.mean(np.array(epoch_times))

    # `out` / `y_test` come from the test-set prediction cell.
    y = y_test.argmax(axis=1)
    preds = out.argmax(axis=1)
    # NOTE(review): `out` is the network's log-softmax output, not
    # probabilities; confirm the plotting helper accepts such scores.
    probs = out
    classes = data.mapping

    roc_curve = wandb.sklearn.plot_roc(y, probs, classes)
    confusion_matrix = wandb.sklearn.plot_confusion_matrix(y, preds, classes)

    wandb.log({"roc_curve": roc_curve})
    wandb.log({"confusion_matrix": confusion_matrix})



In [150]:
# Close the Weights & Biases run (only when logging was enabled).
if args.wandb:
    wandb.finish()

VBox(children=(Label(value='0.360 MB of 0.360 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.999705…

0,1
accuracy,▁▃▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇██████████████████
loss,█▆▅▄▄▄▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▂▁▁▁▁▁▁▁
lr,████████▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▂▃▄▅▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇███████████████████
val_loss,█▆▄▄▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
accuracy,0.68082
avg_epoch_time,15.47968
loss,0.30514
lr,0.0
test_acc,0.66864
test_auc,0.72296
test_loss,0.31326
train_auc,0.73392
val_accuracy,0.65991
val_loss,0.31339


In [None]:
# Print the value of the linear learning-rate schedule for the first 200 steps.
for i in range(200):
    print(i, schedule_fn(i))