In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

from importlib.util import find_spec
if find_spec("qml_hep_lhc") is None:
    import sys
    sys.path.append('..')

In [2]:
from qml_hep_lhc.data import ElectronPhoton, MNIST, QuarkGluon
from qml_hep_lhc.data.utils import tf_ds_to_numpy
import argparse
import wandb

import pennylane as qml
import jax.numpy as jnp
import jax
import optax
from jax.nn.initializers import he_uniform
from jax import grad, jit, vmap
from jax import random
import flax.linen as nn
import tensorflow_datasets as tfds
from tqdm import tqdm
import numpy as np
import tensorflow as tf

# Added to silence some warnings.
# from jax.config import config
# config.update("jax_enable_x64", True)

import matplotlib.pyplot as plt
import time

In [3]:
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [10]:
args = argparse.Namespace()

# Data
args.center_crop = 0.2
# args.resize = [8,8]
args.standardize = 1
# args.power_transform = 1
# args.binary_data = [3,6]
# args.percent_samples = 0.01
# args.processed = 1
args.dataset_type = '3'
args.labels_to_categorical = 1
args.batch_size = 128
args.validation_split = 0.2
# args.graph_conv = 1

# Base Model
args.wandb = False
args.epochs = 10
args.learning_rate = 0.001

# Quantum CNN Parameters
args.n_layers = 1
args.n_qubits = 1
args.template = 'NQubitPQCSparse'
args.initializer = 'he_uniform'

args.num_qconv_layers = 1
args.qconv_dims = [1, 1]
args.kernel_sizes = [(3,3), (3,3)]
args.strides = [(1,1), (1,1)]
args.paddings = ["SAME", "SAME"]

args.clayer_sizes = [8, 2]

In [11]:
if args.wandb:
     wandb.init(project='qml-hep-lhc', config = vars(args))

In [12]:
data = ElectronPhoton(args)
# data.dims = (40,40,2)
data.prepare_data()
data.setup()
print(data)

Center cropping...
Center cropping...
Standardizing data...
Converting labels to categorical...
Converting labels to categorical...

Dataset :Electron Photon 3
╒════════╤═══════════════════╤══════════════════╤══════════════════╤═══════════╕
│ Data   │ Train size        │ Val size         │ Test size        │ Dims      │
╞════════╪═══════════════════╪══════════════════╪══════════════════╪═══════════╡
│ X      │ (320000, 8, 8, 1) │ (80000, 8, 8, 1) │ (98000, 8, 8, 1) │ (8, 8, 1) │
├────────┼───────────────────┼──────────────────┼──────────────────┼───────────┤
│ y      │ (320000, 2)       │ (80000, 2)       │ (98000, 2)       │ (2,)      │
╘════════╧═══════════════════╧══════════════════╧══════════════════╧═══════════╛

╒══════════════╤═══════╤════════╤════════╤═══════╤══════════════════════════╕
│ Type         │   Min │    Max │   Mean │   Std │ Samples for each class   │
╞══════════════╪═══════╪════════╪════════╪═══════╪══════════════════════════╡
│ Train Images │ -2.88 │ 107.37 │     

## Hyperparameters

In [13]:
input_dims = data.config()['input_dims']

In [14]:
input_dims

(8, 8, 1)

In [15]:
def get_out_shape(in_shape,f, k, s, padding):
    in_shape = (1,) + in_shape
    a = np.random.uniform(size = (in_shape))
    dn = jax.lax.conv_dimension_numbers(a.shape, (1,1,k[0],k[1]), ('NHWC', 'IOHW', 'NHWC'))
    out = jax.lax.conv_general_dilated_patches(lhs = a,
                                           filter_shape= k,
                                           window_strides=s,
                                           padding=padding,
                                           dimension_numbers=dn 
                                    )
    return out.shape[1:3]+(f,)

In [16]:
initializer = he_uniform()

# Get qlayer sizes
def get_qlayer_sizes(template, n_l, n_q, k_size):
    if template == 'NQubitPQCSparse':
        return {
            'w': (n_l, n_q,3,np.prod(k_size)),
            'b': (n_l,n_q,3,1)
        }
    elif template == 'NQubitPQC':
        assert np.prod(k_size)%3 == 0
        return {
            'w': (n_l,n_q,np.prod(k_size)),
            'b': (n_l,n_q,np.prod(k_size))
        }

def random_qlayer_params(size, key, filters, n_channels, scale=1e-1):
    w =  initializer(key, size)
    tile_shape = (filters,n_channels,) + (1,)*len(size)
    w = jnp.tile(w, tile_shape)
    return w

def init_qnetwork_params(in_shape, filters, kernel_size, strides, padding, template, n_l, n_q, key):
    n_channels = in_shape[-1]
    sizes = get_qlayer_sizes(template, n_l,n_q, kernel_size)
    keys = random.split(key, len(sizes))
    return [random_qlayer_params(size, key, filters, n_channels) for size, key in zip(sizes.values(), keys)]

# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_clayer_params(m, n, key, scale=1e-1):
    w_key, b_key = random.split(key)
    return initializer(w_key, (n,m)), random.normal(b_key, (n,))
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [random_clayer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

num_qconv_layers = args.num_qconv_layers
qconv_dims = args.qconv_dims
kernel_sizes = args.kernel_sizes
strides = args.strides
paddings = args.paddings
clayer_sizes = args.clayer_sizes

template = args.template
n_layers = args.n_layers
n_qubits = args.n_qubits


in_shape = input_dims
params = []
for l in range(num_qconv_layers):
    qconv_params = init_qnetwork_params(in_shape, 
                                         qconv_dims[l], 
                                         kernel_sizes[l], 
                                         strides[l], 
                                         paddings[l],
                                         template, 
                                         n_layers,
                                         n_qubits,
                                         random.PRNGKey(l))
    params += [qconv_params]
    in_shape = get_out_shape(in_shape,qconv_dims[l],kernel_sizes[l],strides[l],paddings[l]) 

num_pixels = np.prod(in_shape)//(2**0)
# num_pixels = 16*4*4
clayer_sizes = [num_pixels] + clayer_sizes


params += init_network_params(clayer_sizes, random.PRNGKey(2))

  rhs = jnp.eye(spatial_size, dtype=lhs.dtype).reshape(filter_shape * 2)


In [17]:
for i in params:
    for j in i:
        print(j.shape, end = ' ')
    print()

(1, 1, 1, 1, 3, 9) (1, 1, 1, 1, 3, 1) 
(8, 64) (8,) 
(2, 8) (2,) 


## QLayers

In [18]:
dev = qml.device('default.qubit.jax', wires=n_qubits)
qubits =list(range(n_qubits))

@jax.jit
@qml.qnode(dev, interface='jax')
def NQubitPQCSparse(inputs, w, b):
    z = jnp.dot(w, jnp.transpose(inputs))+ b

    for q in qubits:
        qml.Hadamard(wires=q)
    
    for l in range(n_layers):
        for q in qubits:
            qml.Rot(z[l,q,0], z[l,q,1], z[l,q,2], wires= q)
        if (l & 1):
            for q0, q1 in zip(qubits[1::2], qubits[2::2] + [qubits[0]]):
                qml.CZ((q0,q1))
        else:
            for q0, q1 in zip(qubits[0::2], qubits[1::2]):
                qml.CZ((q0,q1))
   
    return [qml.expval(qml.PauliX(qubits[-1])@ qml.PauliY(qubits[-1])@ qml.PauliZ(qubits[-1]))]
#     return [qml.expval(qml.PauliZ(q)) for q in qubits]

In [19]:
dev = qml.device('default.qubit.jax', wires=n_qubits)
qubits =list(range(n_qubits))

@jax.jit
@qml.qnode(dev, interface='jax')
def NQubitPQC(inputs, w, b):
    steps = inputs.shape[-1]//3
    for q in qubits:
        qml.Hadamard(wires=q)
    
    for l in range(n_layers):
        for q in qubits:
            for i in range(steps):
                z = jnp.transpose(jnp.multiply(inputs[:,3*i:3*i+3],w[l,q,3*i:3*i+3]) + b[l,q,3*i:3*i+3])
                qml.RX(z[0], wires=q)
                qml.RY(z[1], wires=q)
                qml.RZ(z[2], wires=q)
                
        if (l & 1):
            for q0, q1 in zip(qubits[1::2], qubits[2::2] + [qubits[0]]):
                qml.CZ((q0,q1))
        else:
            for q0, q1 in zip(qubits[0::2], qubits[1::2]):
                qml.CZ((q0,q1))

    return qml.expval(qml.PauliZ(qubits[-1]))
#     return [qml.expval(qml.PauliZ(q)) for q in qubits]

In [20]:
def get_node(template):
    if template == 'NQubitPQC':
        return NQubitPQC
    elif template == 'NQubitPQCSparse':
        return NQubitPQCSparse

In [21]:
def qconv_cop(x, w,b):
#     print('cop', x.shape, w.shape, b.shape)
    end_dim = x.shape[-1]
    iters = x.shape[1:3]
    x = jnp.reshape(x , (-1,)+ (end_dim,))
    x = get_node(template)(x, w, b)
    x = jnp.reshape(x, (-1,) + iters)
    return x


batched_qconv_cop = vmap(qconv_cop, in_axes=(3, 0, 0))

def qconv_fop(x, w, b):
#     print('op', x.shape, w.shape, b.shape)
    x = batched_qconv_cop(x,w,b)
#     print('op',x.shape)
    x = jnp.sum(x, axis= 0)
#     print('op',x.shape)
    return x


batched_qconv_fop = vmap(qconv_fop, in_axes=(None,0,0))

def qconv(x, params, filters, kernel_size, stride, padding):
    n_channels = x.shape[-1]
    x = jnp.expand_dims(x,axis=0)
    dn = jax.lax.conv_dimension_numbers(x.shape, 
                                        (1,1,kernel_size[0],kernel_size[1]), 
                                        ('NHWC', 'IOHW', 'NHWC'))
    x = jax.lax.conv_general_dilated_patches(lhs = x,
                                               filter_shape= kernel_size,
                                               window_strides=stride,
                                               padding=padding,
                                               dimension_numbers=dn 
                                              )
    iters = x.shape[1:3]
    x = jnp.reshape(x, ((-1,) + iters + (n_channels,) + (np.prod(kernel_size),)))
#     print('conv',x.shape, params[0].shape, params[1].shape)
    x = batched_qconv_fop(x, params[0], params[1])
#     print('conv',x.shape)
    x = jnp.reshape(x, iters + (filters,))
#     print('conv',x.shape)
    return x

In [22]:
random_flattened_image = random.normal(random.PRNGKey(1), input_dims)
random_flattened_image = jnp.floor(random_flattened_image*10)
random_flattened_image.shape

(8, 8, 1)

In [23]:
out = random_flattened_image
for l in range(num_qconv_layers):
    out = qconv(out, 
                params[l],
                qconv_dims[l], 
                kernel_sizes[l], 
                strides[l], 
                paddings[l])
    print(out.shape)

(8, 8, 1)


In [24]:
dev = qml.device("default.qubit", wires=n_qubits)
qnode = qml.QNode(get_node(template), dev)

inputs = np.random.uniform(size = (10,np.prod(kernel_sizes[0])))
weights = [params[0][0][0][0], params[0][1][0][0]]
drawer = qml.draw(qnode, expansion_strategy="device")
print(drawer(inputs,*weights))

0: ──H──Rot─┤  <X@Y@Z>


## Auto-Batching Predictions

In [25]:
from jax.scipy.special import logsumexp

def relu(x):
    return jnp.maximum(0, x)

def forward(params, image):
  # per-example predictions
    activations = image
    for l in range(num_qconv_layers):
        activations = qconv(activations, params[l], qconv_dims[l], kernel_sizes[l], strides[l], paddings[l])
    activations += image
#     activations = nn.max_pool(activations, window_shape=(2, 2), strides=(2, 2))
    activations = relu(activations)
        
    activations = jnp.reshape(activations, (-1))
    for w, b in params[num_qconv_layers:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)

In [26]:
def forwardvgg(params, image):
  # per-example predictions
    activations = image
    for l in [0,1]:
        activations = qconv(activations, params[l], qconv_dims[l], kernel_sizes[l], strides[l], paddings[l])
    activations = nn.max_pool(activations, window_shape=(2, 2), strides=(2, 2)) 
        
    activations = jnp.reshape(activations, (-1))
    for w, b in params[2:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)

In [27]:
def forwardx(params, image):
  # per-example predictions
    activations = qconv(image, 0)
    activations = qconv(activations, 1)
    activations += image
    activations = relu(activations)
    activations = jnp.reshape(activations, (-1))
    for w, b in params[2:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)

In [28]:
# This works on single examples
preds = forward(params,  random_flattened_image)
print(preds)

[   0.     -157.4177]


In [29]:
# Doesn't work with a batch
random_flattened_images = random.normal(random.PRNGKey(1), (2,)+ input_dims)
random_flattened_images = jnp.floor(random_flattened_images*10)
# try:
#     preds = predict(params, random_flattened_images)
# except TypeError:
#     print('Invalid shapes!')

In [30]:
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_forward = vmap(forward, in_axes=(None,0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_forward(params, random_flattened_images)
print(batched_preds)

[[   0.      -168.53624]
 [   0.       -33.28782]]


## Utility and loss functions

In [31]:
from sklearn.metrics import roc_auc_score

def accuracy(y_true, y_pred):
    target_class = jnp.argmax(y_true, axis=1)
    predicted_class = jnp.argmax(y_pred, axis=1)
    return jnp.mean(predicted_class == target_class)
 

def loss_fn(params, images, targets):
    preds = batched_forward(params, images)
    loss_value = -jnp.mean(preds * targets)
    return loss_value, preds

@jit
def update(opt_state, params, x, y):
    _ , grads = jax.value_and_grad(loss_fn, has_aux=True)(params, x, y)
    
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    
    return params, opt_state 


def step(params,x,y):
    loss_value, preds = loss_fn(params, x, y)
    acc = accuracy(y, preds)
    return loss_value, acc

def evaluate(params, ds):
    losses = []
    accs = []
    with tqdm(tfds.as_numpy(ds), unit="batch") as tepoch:
        for x, y in tepoch:
            tepoch.set_description("Validation")
            loss_value, acc = step(params, x, y)
            losses.append(loss_value)
            accs.append(acc)
       
    return jnp.mean(np.array(losses)), jnp.mean(np.array(accs))

def predict(params, ds):
    preds = []
    y_true = []
    with tqdm(tfds.as_numpy(ds), unit="batch") as tepoch:
        for x, y in tepoch:
            preds += list(batched_forward(params, x))
            y_true += list(y)
    
    return np.array(preds), np.array(y_true)

## Training loop

In [32]:
lr = 1e-4

In [33]:
schedule_fn = optax.linear_schedule(transition_steps=150,
                                    init_value=0.2,
                                    end_value=1e-7,
                                    )
# Defining an optimizer in Jax 
# optimizer = optax.adam(learning_rate=schedule_fn)

print(lr)
optimizer = optax.adam(learning_rate=args.learning_rate)
# optimizer = optax.rmsprop(learning_rate=args.learning_rate)
# optimizer = optax.adam(learning_rate=lr)
opt_state = optimizer.init(params)
lr = (lr*0.1)

0.0001


In [35]:
import time

# epochs = args.epochs
epochs = 5

epoch_times = []
for epoch in range(50):
    start_time = time.time()

    with tqdm(tfds.as_numpy(data.train_ds), unit="batch") as tepoch:
        for x, y in tepoch:
            tepoch.set_description(f"Epoch {epoch}")
            params, opt_state = update(opt_state, params, x, y)
        
    epoch_time = time.time() - start_time
    epoch_times.append(epoch_time)
    
    loss, acc = evaluate(params, data.train_ds)
    val_loss, val_acc = evaluate(params, data.val_ds)
    
    print('loss: {} - acc: {}'.format(loss, acc))
    print('val_loss: {} - val_acc: {}'.format(val_loss, val_acc))
    print('time: {}'.format(epoch_time))
    
    if args.wandb:
        wandb.log({"accuracy": acc, 
                   "val_accuracy": val_acc, 
                   'loss':loss_value, 
                   'val_loss':val_loss})


Epoch 0: 100%|████████████████████████████████| 2500/2500 [00:11<00:00, 220.24batch/s, acc=0.6328125, loss=0.3371101]
100%|████████████████████████████████████████████| 625/625 [00:14<00:00, 44.48batch/s, acc=0.484375, loss=0.34968135]


val_loss: 0.34386634826660156 - val_acc: 0.5708624720573425-  time: 11.354404211044312


Epoch 1: 100%|███████████████████████████████| 2500/2500 [00:11<00:00, 226.35batch/s, acc=0.6484375, loss=0.32705674]
100%|██████████████████████████████████████████████| 625/625 [00:12<00:00, 48.15batch/s, acc=0.53125, loss=0.3478833]


val_loss: 0.3368573784828186 - val_acc: 0.5868749618530273-  time: 11.047217845916748


Epoch 2: 100%|████████████████████████████████| 2500/2500 [00:12<00:00, 207.55batch/s, acc=0.6171875, loss=0.3282447]
100%|████████████████████████████████████████████| 625/625 [00:12<00:00, 49.78batch/s, acc=0.609375, loss=0.32696617]


val_loss: 0.33415260910987854 - val_acc: 0.5941500067710876-  time: 12.047548294067383


Epoch 3: 100%|████████████████████████████████| 2500/2500 [00:11<00:00, 210.01batch/s, acc=0.6015625, loss=0.3258879]
100%|███████████████████████████████████████████| 625/625 [00:12<00:00, 49.74batch/s, acc=0.5859375, loss=0.33604515]


val_loss: 0.3310730457305908 - val_acc: 0.6026874780654907-  time: 11.906599760055542


Epoch 4: 100%|███████████████████████████████████| 2500/2500 [00:11<00:00, 215.38batch/s, acc=0.625, loss=0.32229242]
100%|███████████████████████████████████████████| 625/625 [00:12<00:00, 51.24batch/s, acc=0.5546875, loss=0.35427362]


val_loss: 0.32911941409111023 - val_acc: 0.6096000075340271-  time: 11.609376668930054


Epoch 5: 100%|███████████████████████████████| 2500/2500 [00:10<00:00, 235.30batch/s, acc=0.6484375, loss=0.31647885]
100%|████████████████████████████████████████████████| 625/625 [00:12<00:00, 49.60batch/s, acc=0.625, loss=0.3349622]


val_loss: 0.32762229442596436 - val_acc: 0.6133750081062317-  time: 10.627156734466553


Epoch 6: 100%|███████████████████████████████| 2500/2500 [00:11<00:00, 215.43batch/s, acc=0.6484375, loss=0.31670392]
100%|████████████████████████████████████████████| 625/625 [00:12<00:00, 50.30batch/s, acc=0.578125, loss=0.33989546]


val_loss: 0.3266506493091583 - val_acc: 0.618399977684021-  time: 11.60685396194458


Epoch 7: 100%|███████████████████████████████| 2500/2500 [00:11<00:00, 214.61batch/s, acc=0.6953125, loss=0.31151143]
100%|███████████████████████████████████████████| 625/625 [00:12<00:00, 50.09batch/s, acc=0.6328125, loss=0.33671007]


val_loss: 0.32570984959602356 - val_acc: 0.6222875118255615-  time: 11.651486873626709


Epoch 8: 100%|███████████████████████████████| 2500/2500 [00:10<00:00, 231.85batch/s, acc=0.7265625, loss=0.30377764]
100%|███████████████████████████████████████████| 625/625 [00:12<00:00, 49.57batch/s, acc=0.6796875, loss=0.32563698]


val_loss: 0.3246891498565674 - val_acc: 0.6261749863624573-  time: 10.784868001937866


Epoch 9: 100%|███████████████████████████████| 2500/2500 [00:10<00:00, 233.16batch/s, acc=0.6953125, loss=0.30296236]
100%|████████████████████████████████████████████████| 625/625 [00:12<00:00, 49.41batch/s, acc=0.625, loss=0.3367525]


val_loss: 0.32400721311569214 - val_acc: 0.6287499666213989-  time: 10.724617958068848


Epoch 10: 100%|████████████████████████████████| 2500/2500 [00:11<00:00, 225.83batch/s, acc=0.65625, loss=0.32074615]
100%|███████████████████████████████████████████| 625/625 [00:12<00:00, 50.37batch/s, acc=0.6484375, loss=0.32610428]


val_loss: 0.3233955502510071 - val_acc: 0.6295499801635742-  time: 11.072545766830444


Epoch 11: 100%|█████████████████████████████████| 2500/2500 [00:11<00:00, 222.83batch/s, acc=0.6875, loss=0.31594548]
100%|████████████████████████████████████████████████| 625/625 [00:12<00:00, 50.18batch/s, acc=0.625, loss=0.3287763]


val_loss: 0.3227951228618622 - val_acc: 0.6317625045776367-  time: 11.221354007720947


Epoch 12: 100%|███████████████████████████████| 2500/2500 [00:12<00:00, 206.22batch/s, acc=0.671875, loss=0.32565477]
100%|████████████████████████████████████████████| 625/625 [00:12<00:00, 49.06batch/s, acc=0.671875, loss=0.31423014]


val_loss: 0.32260051369667053 - val_acc: 0.6324124932289124-  time: 12.125293016433716


Epoch 13: 100%|█████████████████████████████████| 2500/2500 [00:10<00:00, 228.31batch/s, acc=0.6875, loss=0.31167102]
100%|████████████████████████████████████████████| 625/625 [00:12<00:00, 49.59batch/s, acc=0.7109375, loss=0.3005457]


val_loss: 0.3219517469406128 - val_acc: 0.6343500018119812-  time: 10.952067852020264


Epoch 14: 100%|████████████████████████████████| 2500/2500 [00:12<00:00, 201.17batch/s, acc=0.7265625, loss=0.301268]
100%|████████████████████████████████████████████| 625/625 [00:12<00:00, 49.20batch/s, acc=0.6640625, loss=0.3135856]


val_loss: 0.3212108910083771 - val_acc: 0.6368749737739563-  time: 12.429492235183716


Epoch 15: 100%|███████████████████████████████| 2500/2500 [00:11<00:00, 219.29batch/s, acc=0.6796875, loss=0.3233631]
100%|████████████████████████████████████████████| 625/625 [00:12<00:00, 49.41batch/s, acc=0.578125, loss=0.33342785]


val_loss: 0.32009920477867126 - val_acc: 0.6407999992370605-  time: 11.403373003005981


Epoch 16: 100%|███████████████████████████████| 2500/2500 [00:12<00:00, 206.08batch/s, acc=0.7109375, loss=0.3123554]
100%|███████████████████████████████████████████████| 625/625 [00:12<00:00, 50.10batch/s, acc=0.625, loss=0.32881182]


val_loss: 0.3192066252231598 - val_acc: 0.6428874731063843-  time: 12.133248567581177


Epoch 17: 100%|██████████████████████████████| 2500/2500 [00:11<00:00, 225.24batch/s, acc=0.6484375, loss=0.32645345]
100%|████████████████████████████████████████████| 625/625 [00:12<00:00, 48.49batch/s, acc=0.546875, loss=0.32869756]


val_loss: 0.3181924521923065 - val_acc: 0.6444374918937683-  time: 11.101480484008789


Epoch 18: 100%|██████████████████████████████| 2500/2500 [00:12<00:00, 207.35batch/s, acc=0.6796875, loss=0.31404164]
100%|████████████████████████████████████████████| 625/625 [00:12<00:00, 49.76batch/s, acc=0.609375, loss=0.32327658]


val_loss: 0.31773829460144043 - val_acc: 0.645562469959259-  time: 12.059669494628906


Epoch 19: 100%|████████████████████████████████| 2500/2500 [00:11<00:00, 208.75batch/s, acc=0.703125, loss=0.3123882]
100%|████████████████████████████████████████████| 625/625 [00:12<00:00, 49.80batch/s, acc=0.6015625, loss=0.3178433]


val_loss: 0.3170745372772217 - val_acc: 0.6488249897956848-  time: 11.978140354156494


Epoch 20: 100%|████████████████████████████████| 2500/2500 [00:11<00:00, 227.18batch/s, acc=0.65625, loss=0.32491842]
100%|█████████████████████████████████████████████| 625/625 [00:12<00:00, 49.29batch/s, acc=0.640625, loss=0.3153455]


val_loss: 0.3164973258972168 - val_acc: 0.6494874954223633-  time: 11.006595849990845


Epoch 21: 100%|██████████████████████████████| 2500/2500 [00:11<00:00, 209.26batch/s, acc=0.6640625, loss=0.31085375]
100%|███████████████████████████████████████████████| 625/625 [00:12<00:00, 50.25batch/s, acc=0.625, loss=0.31242284]


val_loss: 0.3159262239933014 - val_acc: 0.650362491607666-  time: 11.949259042739868


Epoch 22: 100%|██████████████████████████████| 2500/2500 [00:12<00:00, 206.05batch/s, acc=0.6171875, loss=0.32784417]
100%|████████████████████████████████████████████| 625/625 [00:12<00:00, 49.51batch/s, acc=0.609375, loss=0.31729382]


val_loss: 0.31534653902053833 - val_acc: 0.6519374847412109-  time: 12.13486647605896


Epoch 23: 100%|██████████████████████████████| 2500/2500 [00:12<00:00, 200.61batch/s, acc=0.6796875, loss=0.31675914]
100%|███████████████████████████████████████████| 625/625 [00:12<00:00, 48.95batch/s, acc=0.6328125, loss=0.32474944]


val_loss: 0.31508633494377136 - val_acc: 0.6520000100135803-  time: 12.46433401107788


Epoch 24: 100%|██████████████████████████████| 2500/2500 [00:11<00:00, 218.66batch/s, acc=0.6640625, loss=0.31730282]
100%|█████████████████████████████████████████████| 625/625 [00:12<00:00, 49.47batch/s, acc=0.65625, loss=0.31044957]


val_loss: 0.31448492407798767 - val_acc: 0.6540125012397766-  time: 11.435445785522461


Epoch 25: 100%|███████████████████████████████| 2500/2500 [00:12<00:00, 207.93batch/s, acc=0.640625, loss=0.33016458]
100%|███████████████████████████████████████████| 625/625 [00:12<00:00, 49.59batch/s, acc=0.6640625, loss=0.31019282]


val_loss: 0.314273476600647 - val_acc: 0.6540499925613403-  time: 12.025696754455566


Epoch 26: 100%|██████████████████████████████████| 2500/2500 [00:11<00:00, 213.33batch/s, acc=0.625, loss=0.32331103]
100%|█████████████████████████████████████████████| 625/625 [00:12<00:00, 49.25batch/s, acc=0.65625, loss=0.32201824]


val_loss: 0.31402650475502014 - val_acc: 0.654449999332428-  time: 11.721124649047852


Epoch 27: 100%|█████████████████████████████████| 2500/2500 [00:11<00:00, 211.75batch/s, acc=0.65625, loss=0.3068384]
100%|███████████████████████████████████████████| 625/625 [00:12<00:00, 48.25batch/s, acc=0.6640625, loss=0.32086313]


val_loss: 0.31355875730514526 - val_acc: 0.6550999879837036-  time: 11.808548927307129


Epoch 28: 100%|██████████████████████████████| 2500/2500 [00:11<00:00, 219.12batch/s, acc=0.5859375, loss=0.34343916]
100%|███████████████████████████████████████████| 625/625 [00:12<00:00, 51.10batch/s, acc=0.6484375, loss=0.31584525]


val_loss: 0.31314244866371155 - val_acc: 0.655662477016449-  time: 11.41171669960022


Epoch 29: 100%|████████████████████████████████| 2500/2500 [00:11<00:00, 219.36batch/s, acc=0.578125, loss=0.3258404]
100%|███████████████████████████████████████████████| 625/625 [00:12<00:00, 50.14batch/s, acc=0.625, loss=0.32827088]


val_loss: 0.313271701335907 - val_acc: 0.6557624936103821-  time: 11.398970365524292


Epoch 30: 100%|██████████████████████████████| 2500/2500 [00:12<00:00, 206.94batch/s, acc=0.6015625, loss=0.31267154]
100%|████████████████████████████████████████████| 625/625 [00:12<00:00, 48.70batch/s, acc=0.6484375, loss=0.3128515]


val_loss: 0.3127617835998535 - val_acc: 0.6560249924659729-  time: 12.082613706588745


Epoch 31: 100%|███████████████████████████████████| 2500/2500 [00:11<00:00, 213.33batch/s, acc=0.625, loss=0.3098627]
100%|███████████████████████████████████████████| 625/625 [00:12<00:00, 49.42batch/s, acc=0.6640625, loss=0.32770076]


val_loss: 0.3126322627067566 - val_acc: 0.6554374694824219-  time: 11.721356391906738


Epoch 32: 100%|███████████████████████████████| 2500/2500 [00:11<00:00, 221.12batch/s, acc=0.640625, loss=0.31413162]
100%|████████████████████████████████████████████| 625/625 [00:12<00:00, 48.22batch/s, acc=0.671875, loss=0.31919086]


val_loss: 0.31252726912498474 - val_acc: 0.6563875079154968-  time: 11.308299779891968


Epoch 33: 100%|███████████████████████████████| 2500/2500 [00:10<00:00, 227.58batch/s, acc=0.6328125, loss=0.3241584]
100%|███████████████████████████████████████████| 625/625 [00:12<00:00, 49.69batch/s, acc=0.6484375, loss=0.31599292]


val_loss: 0.3127530515193939 - val_acc: 0.6562874913215637-  time: 10.987435102462769


Epoch 34: 100%|██████████████████████████████| 2500/2500 [00:12<00:00, 199.06batch/s, acc=0.6484375, loss=0.31176794]
100%|███████████████████████████████████████████████| 625/625 [00:12<00:00, 48.67batch/s, acc=0.6875, loss=0.3144514]


val_loss: 0.31249314546585083 - val_acc: 0.6559749841690063-  time: 12.561524391174316


Epoch 35: 100%|██████████████████████████████| 2500/2500 [00:12<00:00, 205.52batch/s, acc=0.6015625, loss=0.33167154]
100%|████████████████████████████████████████████| 625/625 [00:12<00:00, 48.54batch/s, acc=0.671875, loss=0.33575442]


val_loss: 0.31255316734313965 - val_acc: 0.6561374664306641-  time: 12.166916608810425


Epoch 36: 100%|██████████████████████████████| 2500/2500 [00:12<00:00, 204.89batch/s, acc=0.6015625, loss=0.31038058]
100%|█████████████████████████████████████████████| 625/625 [00:12<00:00, 49.16batch/s, acc=0.59375, loss=0.32829404]


val_loss: 0.3125649392604828 - val_acc: 0.656125009059906-  time: 12.2040274143219


Epoch 37: 100%|██████████████████████████████| 2500/2500 [00:12<00:00, 207.18batch/s, acc=0.6484375, loss=0.31203285]
100%|██████████████████████████████████████████████| 625/625 [00:12<00:00, 48.50batch/s, acc=0.65625, loss=0.3187669]


val_loss: 0.3125206232070923 - val_acc: 0.6562624573707581-  time: 12.06918716430664


Epoch 38: 100%|████████████████████████████████| 2500/2500 [00:11<00:00, 213.65batch/s, acc=0.65625, loss=0.31638584]
100%|█████████████████████████████████████████████| 625/625 [00:12<00:00, 50.41batch/s, acc=0.65625, loss=0.31725177]


val_loss: 0.31258925795555115 - val_acc: 0.6561625003814697-  time: 11.70347809791565


Epoch 39: 100%|██████████████████████████████| 2500/2500 [00:11<00:00, 225.51batch/s, acc=0.6015625, loss=0.32134098]
100%|█████████████████████████████████████████████| 625/625 [00:12<00:00, 50.16batch/s, acc=0.640625, loss=0.3202834]


val_loss: 0.3123466372489929 - val_acc: 0.6562125086784363-  time: 11.088484764099121


Epoch 40: 100%|██████████████████████████████| 2500/2500 [00:11<00:00, 210.42batch/s, acc=0.6015625, loss=0.32454735]
100%|██████████████████████████████████████████████| 625/625 [00:12<00:00, 49.63batch/s, acc=0.6875, loss=0.29677352]


val_loss: 0.3122037351131439 - val_acc: 0.6567249894142151-  time: 11.883696556091309


Epoch 41: 100%|█████████████████████████████████| 2500/2500 [00:12<00:00, 203.86batch/s, acc=0.65625, loss=0.3019056]
100%|███████████████████████████████████████████| 625/625 [00:12<00:00, 49.41batch/s, acc=0.6328125, loss=0.31366307]


val_loss: 0.3119577467441559 - val_acc: 0.6568124890327454-  time: 12.265138864517212


Epoch 42: 100%|██████████████████████████████| 2500/2500 [00:11<00:00, 217.38batch/s, acc=0.6796875, loss=0.30310538]
100%|████████████████████████████████████████████| 625/625 [00:12<00:00, 49.89batch/s, acc=0.671875, loss=0.31501904]


val_loss: 0.31222042441368103 - val_acc: 0.6568250060081482-  time: 11.503151178359985


Epoch 43: 100%|███████████████████████████████| 2500/2500 [00:12<00:00, 203.39batch/s, acc=0.6484375, loss=0.3065328]
100%|███████████████████████████████████████████████| 625/625 [00:12<00:00, 48.33batch/s, acc=0.6875, loss=0.3112695]


val_loss: 0.3119654953479767 - val_acc: 0.6572374701499939-  time: 12.293892860412598


Epoch 44: 100%|██████████████████████████████████| 2500/2500 [00:11<00:00, 216.13batch/s, acc=0.59375, loss=0.328904]
100%|████████████████████████████████████████████| 625/625 [00:12<00:00, 50.93batch/s, acc=0.6640625, loss=0.3121731]


val_loss: 0.31202465295791626 - val_acc: 0.6567999720573425-  time: 11.569178581237793


Epoch 45: 100%|██████████████████████████████| 2500/2500 [00:11<00:00, 222.06batch/s, acc=0.6640625, loss=0.30028874]
100%|████████████████████████████████████████████| 625/625 [00:12<00:00, 49.24batch/s, acc=0.671875, loss=0.30906916]


val_loss: 0.3118770718574524 - val_acc: 0.6574249863624573-  time: 11.260501861572266


Epoch 46: 100%|████████████████████████████████| 2500/2500 [00:10<00:00, 230.35batch/s, acc=0.65625, loss=0.30441967]
100%|███████████████████████████████████████████| 625/625 [00:12<00:00, 49.60batch/s, acc=0.7109375, loss=0.30255795]


val_loss: 0.3117813766002655 - val_acc: 0.6573874950408936-  time: 10.855542182922363


Epoch 47: 100%|██████████████████████████████████| 2500/2500 [00:12<00:00, 208.15batch/s, acc=0.625, loss=0.31163877]
100%|████████████████████████████████████████████| 625/625 [00:12<00:00, 50.01batch/s, acc=0.703125, loss=0.30824646]


val_loss: 0.3116956651210785 - val_acc: 0.6580249667167664-  time: 12.013102293014526


Epoch 48: 100%|███████████████████████████████| 2500/2500 [00:11<00:00, 212.23batch/s, acc=0.671875, loss=0.31091428]
100%|███████████████████████████████████████████| 625/625 [00:12<00:00, 49.55batch/s, acc=0.7109375, loss=0.39424405]


val_loss: 0.3117566704750061 - val_acc: 0.6581999659538269-  time: 11.78219223022461


Epoch 49: 100%|███████████████████████████████| 2500/2500 [00:12<00:00, 198.15batch/s, acc=0.640625, loss=0.30931842]
100%|███████████████████████████████████████████| 625/625 [00:12<00:00, 49.07batch/s, acc=0.6796875, loss=0.30606097]


val_loss: 0.31214311718940735 - val_acc: 0.6580125093460083-  time: 12.61914587020874


In [89]:
test_loss, test_acc = evaluate(params, data.test_ds)
test_loss, test_acc

100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 383/383 [00:13<00:00, 28.59batch/s, acc=0.7163462, loss=0.289188]


(DeviceArray(0.29829407, dtype=float32), DeviceArray(0.6901353, dtype=float32))

In [None]:
from sklearn.metrics import roc_auc_score

out,y_train = predict(params, data.train_ds)
# _, y_test = tf_ds_to_numpy(data.test_ds)
train_auc = roc_auc_score(y_train, out)
train_auc

In [90]:
from sklearn.metrics import roc_auc_score

out,y_test = predict(params, data.test_ds)
# _, y_test = tf_ds_to_numpy(data.test_ds)
test_auc = roc_auc_score(y_test, out)
test_auc

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 383/383 [00:15<00:00, 24.20batch/s]


0.7444356754477301

In [207]:
if args.wandb:
    wandb.run.summary['test_loss'] = test_loss
    wandb.run.summary['test_acc'] = test_acc
    wandb.run.summary['test_auc'] = test_auc
    wandb.run.summary['train_auc'] = train_auc
    wandb.run.summary['avg_epoch_time'] = np.mean(np.array(epoch_times))
    y = y_test.argmax(axis=1)
    preds = out.argmax(axis=1)
    probs = out
    classes = data.mapping

    roc_curve = wandb.sklearn.plot_roc(y, probs, classes)
    confusion_matrix = wandb.sklearn.plot_confusion_matrix(y, preds, classes)

    wandb.log({"roc_curve": roc_curve})
    wandb.log({"confusion_matrix": confusion_matrix})



In [208]:
if args.wandb:
    wandb.finish()

VBox(children=(Label(value='0.360 MB of 0.360 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁▆▅▄▆▆▆▄▇▅▇▆█▆▃▆▄▇▅▇▄▆▆▇▇▇▅▇▇▆▇▄▄▇▅▆▆▅▆▆
loss,█▅▅▇▃▄▄▆▄▅▃▃▁▃▅▄▅▂▆▃▅▄▅▄▂▃▅▂▁▃▂▆▄▃▅▄▃▄▃▂
val_accuracy,▁▅▆▇▇▇▇▇▇▇█▇████████████████████████████
val_loss,█▄▃▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
accuracy,0.66667
loss,0.30837
test_acc,0.68754
test_auc,0.74154
test_loss,0.30152
val_accuracy,0.68685
val_loss,0.3013


In [32]:
for i in range(200):
    print(i, schedule_fn(i))

0 0.2
1 0.19866668
2 0.19733334
3 0.19600001
4 0.19466668
5 0.19333333
6 0.192
7 0.19066668
8 0.18933333
9 0.18800001
10 0.18666668
11 0.18533334
12 0.18400002
13 0.18266669
14 0.18133333
15 0.18
16 0.17866668
17 0.17733334
18 0.17600001
19 0.17466669
20 0.17333335
21 0.17200002
22 0.1706667
23 0.16933335
24 0.16800003
25 0.16666669
26 0.16533335
27 0.16400002
28 0.1626667
29 0.16133335
30 0.16000003
31 0.15866669
32 0.15733334
33 0.15600002
34 0.15466669
35 0.15333335
36 0.15200002
37 0.1506667
38 0.14933336
39 0.14800003
40 0.1466667
41 0.14533336
42 0.14400004
43 0.14266671
44 0.14133336
45 0.14000003
46 0.1386667
47 0.13733336
48 0.13600004
49 0.1346667
50 0.13333336
51 0.13200003
52 0.1306667
53 0.12933336
54 0.12800004
55 0.12666671
56 0.12533337
57 0.124000035
58 0.1226667
59 0.121333376
60 0.12000004
61 0.11866671
62 0.11733337
63 0.116000034
64 0.1146667
65 0.113333374
66 0.11200004
67 0.1106667
68 0.109333366
69 0.10800003
70 0.10666671
71 0.10533337
72 0.10400004
73 0.102666