In [20]:
# Requirements
!pip install tensorflow_datasets flax jax optax tqdm netket imgaug==0.2.6

In [19]:
import jax
import jax.numpy as jnp
import numpy as np
from matplotlib import pyplot as plt
import tensorflow_datasets as tfds
import jax
import jax.numpy as jnp
import jax
import flax
import netket as nk
import jax.numpy as jnp
from jax.experimental import stax
import optax
from flax.training import train_state  # Useful dataclass to keep train state
from functools import partial
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

### Utility Functions

In [21]:
def show_img(img, ax=None, title=None):
  """Shows a single image.
  
  Must be stored as a 3d-tensor where the last dimension is 1 channel (greyscale)
  """
  if ax is None:
    ax = plt.gca()
  ax.imshow(img)
  ax.set_xticks([])
  ax.set_yticks([])
  if title:
    ax.set_title(title)

def show_img_grid(imgs, titles):
  """Shows a grid of images."""
  n = int(np.ceil(len(imgs)**.5))
  _, axs = plt.subplots(n, n, figsize=(3 * n, 3 * n))
  for i, (img, title) in enumerate(zip(imgs, titles)):
    show_img(img, axs[i // n][i % n], title)

## Setting up the dataset

In [22]:
#import
ds_builder = tfds.builder('cifar10')
ds_builder.download_and_prepare()
train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))

#delete the id 
del train_ds['id']
del test_ds['id']


# Normalize
train_ds['image'] = jnp.float32(train_ds['image']) / 255.
test_ds['image'] = jnp.float32(test_ds['image']) / 255.
print("dataset keys:", train_ds.keys())
print(f"The training dataset has shape: {train_ds['image'].shape} and dtype {train_ds['image'].dtype}")
print(f"The test     dataset has shape: {test_ds['image'].shape} and dtype {train_ds['image'].dtype}")
print("")
print(f"The training labels have shape: {train_ds['label'].shape} and dtype {train_ds['label'].dtype}")
print(f"The test     labels have shape: {test_ds['label'].shape} and dtype {test_ds['label'].dtype}")
print("The mean     of the data stored in the images are: ", np.mean(train_ds['image']))
print("The variance of the data stored in the images are: ", np.var(train_ds['image']))

dataset keys: dict_keys(['image', 'label'])
The training dataset has shape: (50000, 32, 32, 3) and dtype float32
The test     dataset has shape: (10000, 32, 32, 3) and dtype float32

The training labels have shape: (50000,) and dtype int64
The test     labels have shape: (10000,) and dtype int64
The mean     of the data stored in the images are:  0.45392698
The variance of the data stored in the images are:  0.059536036


##Non-linearities

In [23]:
def modRelu(z, bias): # relu(|z|+b) * (z / |z|)
    norm = jnp.abs(z)
    scale = nk.nn.relu(norm + bias) / (norm + 1e-6)
    scaled = jax.lax.complex(jnp.real(z)*scale, jnp.imag(z)*scale)
    return scaled
def complex_relu(z):
    return jnp.where(z.real > 0, z, 0)
complex_relu = jax.jit(complex_relu)
modRelu=jax.jit(modRelu)

## DropoutComment

 /!\ You can't use max_pool with complex numbers
 /!\ To apply dropout you need to add a parameter to init and apply :

logits = model.apply({'params': params}, inputs=inputs, rngs={'dropout': dropout_rng})
pars = model.init({'params': jax.random.PRNGKey(seed), 'dropout': jax.random.PRNGKey(seed_dropout)}, sample_input)

 But be careful, RNG is hell with jax but before wondering about dropout we need to make this simple model work


## Define Model

In [24]:
class Model(nk.nn.Module):
  n_classes : int = 10
  @nk.nn.compact
  def __call__(self, x, train):
    #make rng for dropout
    dropout_rng = self.make_rng('dropout')

    #first TWO convolutions 3x3 --> 32. Avgpol 2x2 stride 2x2. Dropout 0.2
    x = nk.nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nk.nn.relu(x)
    x = nk.nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nk.nn.relu(x)
    x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = flax.linen.Dropout(0.2, deterministic=not train)(x)

    #second TWO convolutions 64x64 --> 32. Avgpol 2x2 stride 2x2. Dropout 0.2
    x = nk.nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nk.nn.relu(x)
    x = nk.nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nk.nn.relu(x)
    x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = flax.linen.Dropout(0.2, deterministic=not train)(x)

    #third TWO convolutions 128x128 --> 32. Avgpol 2x2 stride 2x2. Dropout 0.2
    x = nk.nn.Conv(features=128, kernel_size=(3, 3))(x)
    x = nk.nn.relu(x)
    x = nk.nn.Conv(features=128, kernel_size=(3, 3))(x)
    x = nk.nn.relu(x)
    x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = flax.linen.Dropout(0.2, deterministic=not train)(x)


    #flatten and dense 1
    x = x.reshape((x.shape[0], -1)) # Flatten
    x = nk.nn.Dense(features=128)(x)
    x = nk.nn.relu(x)

    #DROPOUT2
    #x = flax.linen.Dropout(0.5, deterministic=not train)(x) 

    #dense2 and softmax
    x = nk.nn.Dense(features=10)(x)
    x = nk.nn.log_softmax(x)
    return x


## Loss functions

In [25]:
# The loss function that we will use
def cross_entropy(*, logits, labels):
    one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
    return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))

In [26]:
dropout_rng, init_dropout = jax.random.split(jax.random.PRNGKey(1))


In [27]:
def loss_fn(params, dropout_rng, images, labels):
    """
    Loss function minimised during training of the model.
    """
    logits = model.apply(params, images, rngs={'dropout' : dropout_rng}, train=True)
    return cross_entropy(logits=logits, labels=labels)

def compute_metrics(*, logits, labels):
    """
    Compute metrics of the model during training.
    
    Returns the loss and the accuracy.
    """
    loss = cross_entropy(logits=logits, labels=labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
      'loss': loss,
      'accuracy': accuracy,
    }
    return metrics

## 4 - Create the setup and training loop

In [28]:
def create_train_state(example, rng, optimiser, dropout_rng):
    """Creates initial `TrainState`, holding the current parameters, state of the
    optimiser and other metadata.
    """
    # Construct the model parameters
    params = model.init({'params' : rng, 'dropout' : dropout_rng}, example, train=True)
        
    # Package all those informations in the model state
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=optimiser)


In [29]:
@jax.jit
def eval_metrics(params, batch, dropout_rng):
    """
    This function evaluates the metrics without training the model.
    
    Used to check the performance of the network on training and test datasets.
    """
    logits = model.apply(params, batch['image'], rngs={'dropout' : dropout_rng}, train=False)
    return compute_metrics(logits=logits, labels=batch['label'])

In [30]:
def fgsm_update(image, data_grad, update_max_norm):
    """
    Compute the FGSM update on an image (or a batch of images)

    @param image: float32 tensor of shape (batch_size, rgb, height, width)
    @param data_grad: float32 tensor of the same shape as `image`. Gradient of the loss with respect to `image`.
    @param update_max_norm: float, the maximum permitted difference between `image` and the output of this function measured in L_inf norm.

    @returns a perturbed version of `image` with the same shape
    """
    # Collect the element-wise sign of the data gradient
    sign_data_grad = jnp.sign(data_grad)
    # Create the perturbed image by adjusting each pixel of the input image
    perturbed_image = image + update_max_norm*sign_data_grad
    # Adding clipping to maintain [0,1] range
    perturbed_image = jnp.clip(perturbed_image, 0, 1)
    # Return the perturbed image
    return perturbed_image

In [31]:
def loss_fn2(images, params, dropout_rng, labels):
    """
    Loss function minimised during training of the model.
    """
    logits = model.apply(params, images, rngs={'dropout' : dropout_rng}, train=True)
    return cross_entropy(logits=logits, labels=labels)


In [32]:
@jax.jit
def train_step(state, batch, dropout_rng):
    """
    Train for a single step.
    
    """
    #Make parameters the only 'free' parameter
    _loss_fn = partial(loss_fn, dropout_rng = dropout_rng, images=batch['image'], labels=batch['label'])
    # construct the function returning the loss value and gradient.
    val_grad_fn = jax.value_and_grad(_loss_fn)
    # compute loss and gradient
    loss, grads = val_grad_fn(state.params)
    grads = jax.tree_map(lambda x: x.conj(), grads) 
    # update the state parameters 
    state = state.apply_gradients(grads=grads)

    metrics = eval_metrics(state.params, batch, dropout_rng)
    
    return state, metrics

In [33]:
def train_epoch(state, train_ds, batch_size, epoch, rng, dropout_rng, *, max_steps=None):
    """Train for a single `epoch`.
    
    And epoch is composed of several steps, where every step is taken by updating
    the network parameters with a small mini-batch.
    """
    
    # total number of training images
    train_ds_size = len(train_ds['image'])
    
    # Compute how many steps are present in this epoch.
    steps_per_epoch = train_ds_size // batch_size

    # Truncate the number of steps
    if max_steps is not None:
        steps_per_epoch = min(steps_per_epoch, max_steps)

    # generate a random permutation of the indices to shuffle the training
    perms = jax.random.permutation(rng, train_ds_size)
    perms = perms[:steps_per_epoch * batch_size]
    perms = perms.reshape((steps_per_epoch, batch_size))
    
    # execute the training step for every mini-batch
    batch_metrics = []
    for perm in perms:
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        state, metrics = train_step(state, batch, dropout_rng)
        batch_metrics.append(metrics)

    # compute mean of metrics across each batch in epoch.
    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
        k: np.mean([metrics[k] for metrics in batch_metrics_np])
            for k in batch_metrics_np[0]}

    return state, epoch_metrics_np


def evaluate_model(params, test_ds, dropout_rng):
    """
    evaluate the performance of the model on the test dataset
    """
    metrics = eval_metrics(params, test_ds, dropout_rng)
    metrics = jax.device_get(metrics)
    summary = jax.tree_map(lambda x: x.item(), metrics)
    return summary['loss'], summary['accuracy']

# 5 - Training model

In [None]:
models_saved = []
for s in range(5):
  # Definition of optimiser HyperParameters
  learning_rate = 0.005
  momentum = 0.9
  optimiser = optax.sgd(learning_rate, momentum)
  num_epochs = 80
  batch_size = 250
  max_steps = 200

  #define rngs
  seed = s #123
  seed_dropout = 10-s #0
  key = {'params': jax.random.PRNGKey(seed), 'dropout': jax.random.PRNGKey(seed_dropout)}

  #init model
  model = Model(n_classes=10)
  sample_input = jnp.ones([1, 32, 32, 3])
  pars = model.init(key, sample_input, train=True)

  # Split the rng to get two keys, one to 'shuffle' the dataset at every iteration,
  # and one to initialise the network
  rng, init_rng = jax.random.split(jax.random.PRNGKey(s))

  state = create_train_state(sample_input, init_rng, optimiser, init_dropout)
  metrics = {"test_loss" : [], "test_accuracy": [], "train_loss":[], "train_accuracy":[]}
  with tqdm(range(1, num_epochs + 1)) as pbar:
      for epoch in pbar:
          # Use a separate PRNG key to permute image data during shuffling
          rng, input_rng = jax.random.split(rng)
          dropout_rng, _ = jax.random.split(dropout_rng)
          # Run an optimization step over a training batch
          state, train_metrics = train_epoch(state, train_ds, batch_size, epoch, input_rng, dropout_rng)
          
          # Evaluate on the test set after each training epoch
          test_loss, test_accuracy = evaluate_model(state.params, test_ds, dropout_rng)
          pbar.write('train epoch: %d, loss: %.4f, accuracy: %.2f' % (epoch, train_metrics['loss'], train_metrics['accuracy'] * 100))
          pbar.write(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (epoch, test_loss, test_accuracy * 100))

          # save data
          metrics["train_loss"].append(train_metrics["loss"])
          metrics["train_accuracy"].append(train_metrics["accuracy"])
          metrics["test_loss"].append(test_loss)
          metrics["test_accuracy"].append(test_accuracy)

  models_saved += [state]

    

  0%|          | 0/80 [00:00<?, ?it/s]

train epoch: 1, loss: 2.1345, accuracy: 19.30
 test epoch: 1, loss: 1.92, accuracy: 30.85
train epoch: 2, loss: 1.8135, accuracy: 34.44
 test epoch: 2, loss: 1.71, accuracy: 38.87
train epoch: 3, loss: 1.6614, accuracy: 40.02
 test epoch: 3, loss: 1.58, accuracy: 43.02
train epoch: 4, loss: 1.5468, accuracy: 44.01
 test epoch: 4, loss: 1.52, accuracy: 44.58
train epoch: 5, loss: 1.4688, accuracy: 46.93
 test epoch: 5, loss: 1.43, accuracy: 48.02
train epoch: 6, loss: 1.4060, accuracy: 49.07
 test epoch: 6, loss: 1.38, accuracy: 50.59
train epoch: 7, loss: 1.3472, accuracy: 51.39
 test epoch: 7, loss: 1.35, accuracy: 52.22
train epoch: 8, loss: 1.2701, accuracy: 54.43
 test epoch: 8, loss: 1.25, accuracy: 55.05
train epoch: 9, loss: 1.2035, accuracy: 57.02
 test epoch: 9, loss: 1.27, accuracy: 55.12
train epoch: 10, loss: 1.1543, accuracy: 58.95
 test epoch: 10, loss: 1.15, accuracy: 58.87
train epoch: 11, loss: 1.1046, accuracy: 60.88
 test epoch: 11, loss: 1.11, accuracy: 60.13
train 

  0%|          | 0/80 [00:00<?, ?it/s]

train epoch: 1, loss: 2.1768, accuracy: 16.81
 test epoch: 1, loss: 1.95, accuracy: 29.59
train epoch: 2, loss: 1.8676, accuracy: 32.25
 test epoch: 2, loss: 1.75, accuracy: 37.84
train epoch: 3, loss: 1.6914, accuracy: 39.25
 test epoch: 3, loss: 1.64, accuracy: 41.73
train epoch: 4, loss: 1.5805, accuracy: 43.42
 test epoch: 4, loss: 1.55, accuracy: 44.54
train epoch: 5, loss: 1.4915, accuracy: 46.45
 test epoch: 5, loss: 1.44, accuracy: 48.05
train epoch: 6, loss: 1.4300, accuracy: 48.78
 test epoch: 6, loss: 1.40, accuracy: 49.95
train epoch: 7, loss: 1.3789, accuracy: 50.39
 test epoch: 7, loss: 1.40, accuracy: 50.16
train epoch: 8, loss: 1.3330, accuracy: 52.33
 test epoch: 8, loss: 1.31, accuracy: 53.27
train epoch: 9, loss: 1.2835, accuracy: 54.05
 test epoch: 9, loss: 1.28, accuracy: 54.60
train epoch: 10, loss: 1.2389, accuracy: 55.63
 test epoch: 10, loss: 1.23, accuracy: 56.15
train epoch: 11, loss: 1.1966, accuracy: 57.31
 test epoch: 11, loss: 1.22, accuracy: 56.67
train 

  0%|          | 0/80 [00:00<?, ?it/s]

train epoch: 1, loss: 2.1216, accuracy: 19.87
 test epoch: 1, loss: 1.90, accuracy: 31.68
train epoch: 2, loss: 1.8063, accuracy: 35.29
 test epoch: 2, loss: 1.72, accuracy: 38.61
train epoch: 3, loss: 1.6508, accuracy: 40.87
 test epoch: 3, loss: 1.61, accuracy: 41.77
train epoch: 4, loss: 1.5503, accuracy: 44.11
 test epoch: 4, loss: 1.52, accuracy: 45.32
train epoch: 5, loss: 1.4520, accuracy: 47.52
 test epoch: 5, loss: 1.45, accuracy: 48.21
train epoch: 6, loss: 1.3922, accuracy: 49.68
 test epoch: 6, loss: 1.37, accuracy: 50.70
train epoch: 7, loss: 1.3309, accuracy: 52.21
 test epoch: 7, loss: 1.31, accuracy: 52.78
train epoch: 8, loss: 1.2823, accuracy: 54.05
 test epoch: 8, loss: 1.28, accuracy: 54.27
train epoch: 9, loss: 1.2275, accuracy: 56.39
 test epoch: 9, loss: 1.22, accuracy: 56.29
train epoch: 10, loss: 1.1748, accuracy: 58.23
 test epoch: 10, loss: 1.20, accuracy: 57.15
train epoch: 11, loss: 1.1284, accuracy: 59.98
 test epoch: 11, loss: 1.13, accuracy: 59.37
train 

  0%|          | 0/80 [00:00<?, ?it/s]

train epoch: 1, loss: 2.1026, accuracy: 21.54
 test epoch: 1, loss: 1.89, accuracy: 33.22
train epoch: 2, loss: 1.8210, accuracy: 34.88
 test epoch: 2, loss: 1.74, accuracy: 38.11
train epoch: 3, loss: 1.6533, accuracy: 40.76
 test epoch: 3, loss: 1.60, accuracy: 42.68
train epoch: 4, loss: 1.5263, accuracy: 45.13
 test epoch: 4, loss: 1.47, accuracy: 47.26
train epoch: 5, loss: 1.4552, accuracy: 47.70
 test epoch: 5, loss: 1.43, accuracy: 49.43
train epoch: 6, loss: 1.3941, accuracy: 49.96
 test epoch: 6, loss: 1.39, accuracy: 50.68
train epoch: 7, loss: 1.3289, accuracy: 52.20
 test epoch: 7, loss: 1.33, accuracy: 52.15
train epoch: 8, loss: 1.2882, accuracy: 53.79
 test epoch: 8, loss: 1.28, accuracy: 54.27
train epoch: 9, loss: 1.2343, accuracy: 55.90
 test epoch: 9, loss: 1.25, accuracy: 55.89
train epoch: 10, loss: 1.1940, accuracy: 57.40
 test epoch: 10, loss: 1.21, accuracy: 56.77
train epoch: 11, loss: 1.1503, accuracy: 59.03
 test epoch: 11, loss: 1.20, accuracy: 56.98
train 

In [None]:
import pickle
models_params = [m.params for m in models_saved]
with open("cifar_real_nonrobust.txt", "wb") as fp:   #Pickling
  pickle.dump(models_params, fp)

In [None]:
#computing the number of parameters
tot_params = 0
for chiave in models_saved[0].params['params'].keys():
  for sotto_chiave in models_saved[0].params['params'][chiave]:
    print(chiave, ' \t', 
          sotto_chiave, '  \t', 
          models_saved[0].params['params'][chiave][sotto_chiave].size, '    \t', 
          models_saved[0].params['params'][chiave][sotto_chiave].dtype)
    tot_params += models_saved[0].params['params'][chiave][sotto_chiave].size
print('tot: ', tot_params)

Conv_0  	 bias   	 32     	 float64
Conv_0  	 kernel   	 864     	 float64
Conv_1  	 bias   	 32     	 float64
Conv_1  	 kernel   	 9216     	 float64
Conv_2  	 bias   	 64     	 float64
Conv_2  	 kernel   	 18432     	 float64
Conv_3  	 bias   	 64     	 float64
Conv_3  	 kernel   	 36864     	 float64
Conv_4  	 bias   	 128     	 float64
Conv_4  	 kernel   	 73728     	 float64
Conv_5  	 bias   	 128     	 float64
Conv_5  	 kernel   	 147456     	 float64
Dense_0  	 bias   	 128     	 float64
Dense_0  	 kernel   	 262144     	 float64
Dense_1  	 bias   	 10     	 float64
Dense_1  	 kernel   	 1280     	 float64
tot:  550570
