In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

from importlib.util import find_spec
if find_spec("qml_hep_lhc") is None:
    import sys
    sys.path.append('..')

In [31]:
from jax_models.jax_models.models.swin_transformer import SwinTransformer
from qml_hep_lhc.data import ElectronPhoton, MNIST, QuarkGluon
import argparse
from jax import random
import jax.numpy as jnp
import jax
from flax.training import train_state
import optax
from tqdm import tqdm
import time
import numpy as np
import tensorflow_datasets as tfds

In [42]:
# Data-module options consumed by `ElectronPhoton`.
# Optional knobs not set here: resize=[8, 8], binary_data=[0, 1], percent_samples=0.1.
args = argparse.Namespace(
    center_crop=0.7,
    standardize=1,
    dataset_type='2',
    labels_to_categorical=1,
    batch_size=128,
    validation_split=0.05,
)

In [43]:
def _build_dataset(dataset_cls, options):
    """Instantiate `dataset_cls(options)` and run `prepare_data()` then `setup()` on it."""
    dataset = dataset_cls(options)
    dataset.prepare_data()
    dataset.setup()
    return dataset


data = _build_dataset(ElectronPhoton, args)
print(data)

Center cropping...


2022-08-27 20:03:34.268499: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 409600000 exceeds 10% of free system memory.
2022-08-27 20:03:35.176277: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 230400000 exceeds 10% of free system memory.


Center cropping...
Standardizing data...


2022-08-27 20:03:35.756961: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 81920000 exceeds 10% of free system memory.


Converting labels to categorical...
Converting labels to categorical...


2022-08-27 20:03:43.260000: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 218880000 exceeds 10% of free system memory.
2022-08-27 20:03:43.933762: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 218880000 exceeds 10% of free system memory.



Dataset :Electron Photon 2
╒════════╤════════════════════╤═══════════════════╤════════════════════╤═════════════╕
│ Data   │ Train size         │ Val size          │ Test size          │ Dims        │
╞════════╪════════════════════╪═══════════════════╪════════════════════╪═════════════╡
│ X      │ (95000, 24, 24, 1) │ (5000, 24, 24, 1) │ (20000, 24, 24, 1) │ (24, 24, 1) │
├────────┼────────────────────┼───────────────────┼────────────────────┼─────────────┤
│ y      │ (95000, 2)         │ (5000, 2)         │ (20000, 2)         │ (2,)        │
╘════════╧════════════════════╧═══════════════════╧════════════════════╧═════════════╛

╒══════════════╤═══════╤════════╤════════╤═══════╤══════════════════════════╕
│ Type         │   Min │    Max │   Mean │   Std │ Samples for each class   │
╞══════════════╪═══════╪════════╪════════╪═══════╪══════════════════════════╡
│ Train Images │ -2.88 │ 267.37 │      0 │  1    │ [47500, 47500]           │
├──────────────┼───────┼────────┼────────┼───────┼

In [44]:
# Swin Transformer classifier; keyword arguments grouped by role.
model = SwinTransformer(
    # patch embedding
    patch_size=2,
    emb_dim=96,
    use_abs_pos_emb=False,
    # transformer stage configuration
    depths=(2,),
    num_heads=(3,),
    window_size=2,
    mlp_ratio=4,
    use_att_bias=True,
    # regularisation
    dropout=0.0,
    att_dropout=0.0,
    drop_path=0.1,
    # classification head
    attach_head=True,
    num_classes=2,
    # left unset on the module; a deterministic flag is passed positionally per call
    deterministic=None,
)

In [45]:
# Three keys derived from seed 0: rng1 -> parameter init, rng2 -> "dropout",
# rng3 -> "drop_path" (the RNG streams passed to model.init / apply).
rng1, rng2, rng3 = random.split(
    random.PRNGKey(0),
    num=3,
)

In [46]:
@jax.jit
def apply_model(state, images, labels):
    """Compute gradients, mean loss and accuracy for one training batch.

    The forward pass is run with the model's positional deterministic flag set
    to ``False`` (training mode), so dropout / drop_path are active.

    Args:
        state: ``flax.training.train_state.TrainState``; ``state.params`` are
            differentiated and ``state.step`` seeds this step's RNG keys.
        images: batch of input images.
        labels: one-hot targets aligned with ``images``.

    Returns:
        ``(grads, loss, accuracy)``: gradients with the structure of
        ``state.params``, the batch-mean softmax cross-entropy, and the fraction
        of samples whose argmax prediction matches ``argmax(labels)``.
    """
    # Closed-over arrays become compile-time constants under jax.jit, so using
    # rng2/rng3 directly would give every step identical dropout / drop_path
    # masks. Folding in the optimiser step yields fresh, reproducible keys.
    dropout_rng = random.fold_in(rng2, state.step)
    drop_path_rng = random.fold_in(rng3, state.step)

    def loss_fn(params):
        # With `mutable=...`, apply returns (logits, mutated_collections); the
        # mutated collections are discarded here.
        logits, _ = state.apply_fn(
            {'params': params},
            images,
            False,  # deterministic flag: False -> training mode
            rngs={"dropout": dropout_rng, "drop_path": drop_path_rng},
            mutable=["attention_mask", "relative_position_index"],
        )
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=labels))
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == jnp.argmax(labels, -1))
    return grads, loss, accuracy

def update_model(state, grads):
    """Apply one optimiser step and return the resulting state.

    ``state.apply_gradients`` returns a new state object; the caller must
    rebind it (``state = update_model(state, grads)``).
    """
    next_state = state.apply_gradients(grads=grads)
    return next_state


update_model = jax.jit(update_model)

@jax.jit
def _eval_step(state, images, labels):
    """Forward-only evaluation of one batch; returns ``(mean loss, accuracy)``.

    No gradients are computed, and the model's positional deterministic flag is
    set to ``True`` so dropout / drop_path are disabled.
    NOTE(review): assumes the second positional argument of the Swin model's
    ``__call__`` is ``deterministic`` (the module is constructed with
    ``deterministic=None`` and training passes ``False`` there) -- confirm
    against jax_models.
    """
    logits, _ = state.apply_fn(
        {'params': state.params},
        images,
        True,  # deterministic flag: True -> inference mode
        # Passed defensively; with deterministic=True these keys should be unused.
        rngs={"dropout": rng2, "drop_path": rng3},
        mutable=["attention_mask", "relative_position_index"],
    )
    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=labels))
    accuracy = jnp.mean(jnp.argmax(logits, -1) == jnp.argmax(labels, -1))
    return loss, accuracy


def evaluate(state, ds):
    """Evaluate ``state`` on every ``(images, labels)`` batch of ``ds``.

    Args:
        state: ``TrainState`` to evaluate; it is not modified.
        ds: dataset accepted by ``tfds.as_numpy`` yielding ``(images, labels)``.

    Returns:
        ``(eval_loss, eval_accuracy)`` averaged over samples: each batch is
        weighted by its size, so a short final batch is not over-weighted.
        Both values are ``nan`` when ``ds`` yields no batches.
    """
    losses, accs, sizes = [], [], []

    with tqdm(tfds.as_numpy(ds), unit="batch", desc="Evaluation") as tepoch:
        for x, y in tepoch:
            loss, accuracy = _eval_step(state, x, y)
            losses.append(loss)
            accs.append(accuracy)
            sizes.append(x.shape[0])

    if not sizes:
        return float("nan"), float("nan")

    eval_loss = np.average(losses, weights=sizes)
    eval_accuracy = np.average(accs, weights=sizes)
    return eval_loss, eval_accuracy

def train_epoch(state, train_ds, epoch):
    """Run one optimisation pass over ``train_ds``.

    Args:
        state: current ``TrainState``.
        train_ds: dataset accepted by ``tfds.as_numpy`` yielding
            ``(images, labels)`` batches.
        epoch: epoch number, used only for the progress-bar label.

    Returns:
        ``(state, train_loss, train_accuracy, epoch_time)``:
        - ``state`` is the updated state.
        - ``train_loss`` and ``train_accuracy`` are sample-weighted means over
          the epoch (``nan`` if the dataset is empty).
        - ``epoch_time`` is the wall-clock seconds for the pass.
    """
    batch_losses, batch_accs, batch_sizes = [], [], []

    start_time = time.time()
    with tqdm(tfds.as_numpy(train_ds), unit="batch", desc=f"Epoch {epoch}") as tepoch:
        for x, y in tepoch:
            # Metrics are measured on the parameters *before* this step's update.
            grads, loss, accuracy = apply_model(state, x, y)
            state = update_model(state, grads)
            batch_losses.append(loss)
            batch_accs.append(accuracy)
            batch_sizes.append(x.shape[0])
    # JAX dispatches work asynchronously; wait for the final update to finish
    # so `epoch_time` covers all of the epoch's computation.
    jax.tree_util.tree_map(lambda leaf: leaf.block_until_ready(), state.params)
    epoch_time = time.time() - start_time

    if batch_sizes:
        # Weight by batch size: the last batch is usually smaller than the rest.
        train_loss = np.average(batch_losses, weights=batch_sizes)
        train_accuracy = np.average(batch_accs, weights=batch_sizes)
    else:
        train_loss = train_accuracy = float("nan")

    return state, train_loss, train_accuracy, epoch_time

def create_train_state(rng, input_shape=(1, 32, 32, 1), learning_rate=0.01):
    """Initialise ``model`` and wrap it in a ``TrainState`` with an Adam optimiser.

    Args:
        rng: PRNG key. It is split into independent keys for parameter init,
            the "dropout" and "drop_path" streams, and the dummy input.
        input_shape: shape of the dummy batch used to trace ``model.init``.
            NOTE(review): the default 32x32 may differ from the crops the
            dataset yields; pass the data's shape if the model's parameters
            depend on spatial size.
        learning_rate: Adam step size.

    Returns:
        ``flax.training.train_state.TrainState`` whose ``apply_fn`` is
        ``model.apply`` and whose ``params`` are the freshly initialised
        parameters.
    """
    params_rng, dropout_rng, drop_path_rng, input_rng = random.split(rng, 4)
    dummy_input = random.normal(input_rng, input_shape)
    variables = model.init(
        {"params": params_rng, "dropout": dropout_rng, "drop_path": drop_path_rng},
        dummy_input,
        False,  # deterministic flag: False -> training mode
    )
    # Only the trainable parameters are kept; the other collections returned by
    # init are passed as `mutable` at apply time instead.
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=variables['params'], tx=tx)


def train_and_evaluate(data, epochs=10, seed=0):
    """Train a freshly initialised model on ``data`` and validate after every epoch.

    Args:
        data: dataset object exposing ``train_ds`` and ``val_ds``, e.g. an
            ``ElectronPhoton`` instance after ``setup()``.
        epochs: number of passes over ``data.train_ds``.
        seed: integer seed for the initialisation PRNG key (default 0).

    Returns:
        The final train state, which includes the trained ``.params``.
    """
    rng = jax.random.PRNGKey(seed)
    _, init_rng = jax.random.split(rng)
    state = create_train_state(init_rng)

    epoch_times = []
    for epoch in range(1, epochs + 1):
        state, loss, acc, epoch_time = train_epoch(state, data.train_ds, epoch)
        val_loss, val_acc = evaluate(state, data.val_ds)

        epoch_times.append(epoch_time)

        print('loss: {} - acc: {}'.format(loss, acc))
        print('val_loss: {} - val_acc: {}'.format(val_loss, val_acc))
        print('time: {}'.format(epoch_time))

    # Avoid np.mean([]) (nan + RuntimeWarning) when epochs == 0.
    if epoch_times:
        print('Avg epoch time: {}'.format(np.mean(epoch_times)))
    return state

In [None]:
# Full training run: 10 epochs, validating after each one.
state = train_and_evaluate(data, epochs=10)

Epoch 1:  65%|██████████████████████████████████████████▋                       | 480/743 [12:21<07:24,  1.69s/batch]