In [None]:
#| default_exp callbacks

# Callbacks

> Useful callbacks to use with the functional layers.

In [None]:
#| export
import wandb

import tensorflow as tf
from tensorflow.keras.callbacks import Callback

from flayers.layers import *

In [None]:
#| hide
import wandb
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist
from einops import repeat

## Gabor parameter logging

> Logging Gabor parameters into *wandb*

We found that sometimes, during training, an error would be raised when inverting the matrix used to compute the Gabor filters. Our first thought was that the covariance matrix (built from the parameters `sigma_i` and `sigma_j`) wasn't invertible, which would happen if `sigma_i` or `sigma_j` were non-positive, but constraining these variables did not fix the problem. To inspect it in more detail, we are going to log all the layer's weights to *wandb* during training and try to find the root of the problem.

> To avoid introducing a dependency that many users won't need, we can import `wandb` inside the callback itself, so it is only required when the callback is actually used.

In [None]:
#| export

class GaborLayerLogger(Callback):
    """Logs the gabor parameters into wandb during training.

    After every training batch, each weight of every `GaborLayer` found in
    `self.model.layers` is sent to wandb as a histogram, keyed as
    `"{layer.name}.{weight.name}"`.
    """

    def on_batch_end(self, 
                     batch, # Batch number.
                     logs=None, # Dictionary containing metrics and information of the training.
                     ):
        """Logs the gabor parameters after each batch (after each parameter update)."""
        # Imported lazily so that wandb is only needed when this callback actually runs
        # (an import in the class body would only create a class attribute).
        import wandb

        histograms = {
            f'{layer.name}.{weight.name}': wandb.Histogram(weight.numpy())
            for layer in self.model.layers
            if isinstance(layer, GaborLayer)
            for weight in layer.weights
        }
        # Every `wandb.log` call advances wandb's step counter. Logging all weights in a
        # single call keeps one batch's update on one step instead of spreading it
        # across several. Skip the call entirely when there is nothing to log.
        if histograms:
            wandb.log(histograms)

In [None]:
#| export

class GaborLayerSeqLogger(Callback):
    """Logs the gabor parameters into wandb during training.

    Variant of the Gabor logger for models that expose their layers through a
    `feature_extractor` attribute (assumed to have a `.layers` list — TODO confirm
    against the model classes this is used with). After every training batch, each
    weight of every `GaborLayer` in `self.model.feature_extractor.layers` is sent
    to wandb as a histogram, keyed as `"{layer.name}.{weight.name}"`.
    """

    def on_batch_end(self, 
                     batch, # Batch number.
                     logs=None, # Dictionary containing metrics and information of the training.
                     ):
        """Logs the gabor parameters after each batch (after each parameter update)."""
        # Imported lazily so that wandb is only needed when this callback actually runs
        # (an import in the class body would only create a class attribute).
        import wandb

        histograms = {
            f'{layer.name}.{weight.name}': wandb.Histogram(weight.numpy())
            for layer in self.model.feature_extractor.layers
            if isinstance(layer, GaborLayer)
            for weight in layer.weights
        }
        # Every `wandb.log` call advances wandb's step counter. Logging all weights in a
        # single call keeps one batch's update on one step instead of spreading it
        # across several. Skip the call entirely when there is nothing to log.
        if histograms:
            wandb.log(histograms)

Let's check if it logs the parameters appropriately:

In [None]:
#| hide
(X_train, Y_train), (X_test, Y_test) = mnist.load_data()

# Append a single channel axis (b, h, w) -> (b, h, w, 1) and rescale pixel values by 1/255.
X_train, X_test = (
    repeat(images, "b h w ->  b h w c", c=1) / 255.0
    for images in (X_train, X_test)
)

X_train.shape, Y_train.shape, X_test.shape, Y_test.shape

((60000, 28, 28, 1), (60000,), (10000, 28, 28, 1), (10000,))

## Definition of a simple model

In [None]:
# Small classifier: a Gabor feature layer followed by pooling and a softmax head.
model = tf.keras.Sequential()
model.add(RandomGabor(n_gabors=4, size=20, input_shape=(28,28,1)))
model.add(layers.MaxPool2D(2))
model.add(layers.GlobalAveragePooling2D())
model.add(layers.Dense(10, activation="softmax"))

model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
model.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
random_gabor_3 (RandomGabor) (None, 28, 28, 4)         1626      
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 14, 14, 4)         0         
_________________________________________________________________
global_average_pooling2d_1 ( (None, 4)                 0         
_________________________________________________________________
dense_1 (Dense)              (None, 10)                50        
Total params: 1,676
Trainable params: 76
Non-trainable params: 1,600
_________________________________________________________________


In [None]:
# Training hyperparameters; handed to wandb so they are recorded with the run.
config = dict(epochs=5, batch_size=64)

In [None]:
wandb.init(
    project="Testing",
    config=config,
)
# From here on, hyperparameters are read back from the wandb run's config object.
config = wandb.config

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjorgvt[0m (use `wandb login --relogin` to force relogin)


In [None]:
model.fit(
    X_train,
    Y_train,
    epochs=config.epochs,
    batch_size=config.batch_size,
    callbacks=[GaborLayerLogger()],  # logs the Gabor weights to wandb after every batch
)

In [None]:
# Close the wandb run started above.
wandb.finish()


