# Intro notebook to the dense Self-Organizing-Feature layer

## Import and setup

In [1]:
%env TF_CPP_MIN_LOG_LEVEL=3

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

env: TF_CPP_MIN_LOG_LEVEL=3


## Load and preprocess the dataset

In [2]:
# Load the MNIST train and test splits as (image, label) pairs, together
# with the dataset info object that is used below to size the shuffle buffer.
mnist_splits, ds_info = tfds.load(
    name="mnist",
    split=["train", "test"],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
ds_train, ds_test = mnist_splits


def normalize_img(image, label):
    """Convert an image from `uint8` to `float32`, rescaled by 1/255.

    The image is also reshaped to (28, 28, 1); the label is passed through
    unchanged so the function can be used directly with `Dataset.map`.
    """
    scaled = tf.cast(image, tf.float32) / 255.0
    return tf.reshape(scaled, (28, 28, 1)), label


# Training pipeline: normalize, cache the normalized examples, shuffle with a
# buffer covering the whole training split, then batch and prefetch.
ds_train = (
    ds_train.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(ds_info.splits["train"].num_examples)
    .batch(128)
    .prefetch(tf.data.AUTOTUNE)
)

# Test pipeline: no shuffling; batching happens before caching here, so the
# cache holds whole batches.
ds_test = (
    ds_test.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(128)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

## Define the dense layer with the Self-Organizing-Feature concept

In [3]:
class DenseSOCLayer(tf.keras.layers.Layer):
    """Dense layer that scores inputs against learned diagonal-Gaussian units.

    Every output unit owns a location vector (one row of `projection`) and a
    per-dimension scale vector (one row of `scale_diag`). For each input
    vector the layer returns, per unit, the unnormalized Gaussian
    log-density of the input under that unit, summed over the input
    dimensions.

    Args:
        features: Number of output units.
        **kwargs: Forwarded to `tf.keras.layers.Layer`.
    """

    def __init__(self, features=32, **kwargs):
        super().__init__(**kwargs)
        self.features = features

    def build(self, input_shape):
        """Create the per-unit location and scale weights.

        Both weights have shape `(features, input_shape[-1])`.
        """
        self.projection = self.add_weight(
            shape=(self.features, input_shape[-1]),
            initializer=tf.keras.initializers.VarianceScaling(
                scale=1.0, mode="fan_out", distribution="uniform", seed=None
            ),
            trainable=True,
        )
        # NOTE: the scales start at one and are unconstrained during training;
        # a scale reaching zero makes the divisions in `_log_prob` blow up.
        self.scale_diag = self.add_weight(
            shape=(self.features, input_shape[-1]),
            initializer="ones",
            trainable=True,
        )

    def get_config(self):
        """Return the layer config, including `features`.

        Without this, a model containing the layer cannot be re-created from
        its saved config with the same number of units.
        """
        config = super().get_config()
        config.update({"features": self.features})
        return config

    @staticmethod
    def _log_prob(mu, scale_diag, x, unnormalized):
        """Gaussian log-density of `x`, summed over all elements.

        Args:
            mu: Location tensor.
            scale_diag: Scale tensor; `x` and `mu` are divided by it.
            x: Input tensor.
            unnormalized: If true, omit the normalization term
                `0.5 * log(2 * pi) + log(scale_diag)`.

        Returns:
            The sum over every element of the (optionally normalized)
            element-wise log-density.
        """
        log_unnormalized = -0.5 * tf.math.squared_difference(
            x / scale_diag, mu / scale_diag
        )
        if unnormalized:
            return tf.reduce_sum(log_unnormalized)
        log_normalization = tf.constant(
            0.5 * np.log(2.0 * np.pi), dtype=mu.dtype
        ) + tf.math.log(scale_diag)
        return tf.reduce_sum(log_unnormalized - log_normalization)

    def log_prob(self, mu, scale_diag, x, unnormalized=True):
        """Evaluate `_log_prob` for every (input, unit) pair.

        The outer `tf.vectorized_map` iterates over the leading axis of `x`,
        the inner one over the leading axis of `mu` / `scale_diag`, so the
        result is indexed as `[input, unit]`.
        """
        batch_log_prob = tf.vectorized_map(
            lambda _x: tf.vectorized_map(
                lambda _params: self._log_prob(
                    _params[0], _params[1], _x, unnormalized=unnormalized
                ),
                (mu, scale_diag),
            ),
            x,
        )
        return batch_log_prob

    def call(self, inputs):
        """Return the unnormalized log-density of `inputs` under each unit."""
        log_probs = self.log_prob(self.projection, self.scale_diag, inputs)
        return log_probs

## Define the model with a single SOF layer

In [4]:
# Classifier with a single SOF layer between two layer normalizations.
model = tf.keras.models.Sequential(
    layers=[
        tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
        tf.keras.layers.Dense(units=40),
        tf.keras.layers.LayerNormalization(),
        DenseSOCLayer(features=10),
        tf.keras.layers.LayerNormalization(),
        tf.keras.layers.Dense(units=10),
    ]
)

# The final Dense layer emits raw logits, hence `from_logits=True`.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    run_eagerly=False,
)
model.summary()

# Train for 15 epochs, reporting metrics on the test split after each epoch.
model.fit(
    ds_train,
    validation_data=ds_test,
    epochs=15,
    verbose=2,
)

print("\nEval:")
_ = model.evaluate(ds_test, verbose=2)

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flatten (Flatten)           (None, 784)               0         
                                                                 
 dense (Dense)               (None, 40)                31400     
                                                                 
 layer_normalization (LayerN  (None, 40)               80        
 ormalization)                                                   
                                                                 
 dense_soc_layer (DenseSOCLa  (None, 10)               800       
 yer)                                                            
                                                                 
 layer_normalization_1 (Laye  (None, 10)               20        
 rNormalization)                                                 
                                                        

## Define a baseline model without the SOF layer
### The Dense layer replacing it has twice as many features (20 vs. 10), which gives roughly the same parameter count (820 vs. 800).

In [5]:
# Baseline classifier: a plain ReLU Dense layer takes the place of the SOF layer.
model = tf.keras.models.Sequential(
    layers=[
        tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
        tf.keras.layers.Dense(units=40, activation="relu"),
        tf.keras.layers.LayerNormalization(),
        # Twice as many features as the SOF layer, to have the ~same
        # parameter count.
        tf.keras.layers.Dense(units=20, activation="relu"),
        tf.keras.layers.LayerNormalization(),
        tf.keras.layers.Dense(units=10),
    ]
)

# The final Dense layer emits raw logits, hence `from_logits=True`.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.summary()

# Train for 15 epochs, reporting metrics on the test split after each epoch.
model.fit(
    ds_train,
    validation_data=ds_test,
    epochs=15,
    verbose=2,
)

print("\nEval:")
_ = model.evaluate(ds_test, verbose=2)

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flatten_1 (Flatten)         (None, 784)               0         
                                                                 
 dense_2 (Dense)             (None, 40)                31400     
                                                                 
 layer_normalization_2 (Laye  (None, 40)               80        
 rNormalization)                                                 
                                                                 
 dense_3 (Dense)             (None, 20)                820       
                                                                 
 layer_normalization_3 (Laye  (None, 20)               40        
 rNormalization)                                                 
                                                                 
 dense_4 (Dense)             (None, 10)               