# **Image Classification**: A Deep Generalized Convolutional Sum-Product Network (DGC-SPN) with `libspn-keras`
Let's go through an example of building complex SPNs with [`libspn-keras`](https://github.com/pronobis/libspn-keras). The layer-based API of the library makes it straightforward to build advanced SPN architectures.

First let's set up the dependencies:

In [None]:
!pip install libspn-keras tensorflow-datasets -q

## DGC-SPN
A DGC-SPN consists of convolutional product and sum nodes. For the sake of 
demonstration, we'll use a structure that trains relatively quickly, without worrying too much about the final accuracy of the model. 

### Setting the Default Sum Operation

We'll take the default approach to training deep models for classification through gradient-based minimization of
cross-entropy. To ensure our sum operations in the SPN pass down gradients instead of EM signals, we set
the default sum operation to `SumOpGradBackprop`. In fact, this is the default setting of `libspn-keras`:

In [None]:
import libspn_keras as spnk

print(spnk.get_default_sum_op())

spnk.set_default_sum_op(spnk.SumOpGradBackprop())

### Setting the Default Sum Accumulator Initializer

In `libspn-keras`, we refer to the unnormalized weights as _accumulators_. These can be represented in linear space or logspace. Setting the `SumOp` also configures the default representation space for these accumulators. For example, gradients should be used for _discriminative_ learning, and accumulators are then preferably represented in logspace. This removes the need to project the accumulators onto $\mathbb R^+$ after gradient updates, since log accumulators can take any value in $\mathbb R$ (whereas linear accumulators are limited to $\mathbb R^+$).

To set the default initial value (which will be transformed to logspace internally if needed), one can use `spnk.set_default_accumulator_initializer`:

In [None]:
from tensorflow import keras

spnk.set_default_accumulator_initializer(
    keras.initializers.TruncatedNormal(stddev=0.5, mean=1.0)
)
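
To make the difference between the two representations concrete, here is a minimal plain-Python sketch. It is not the `libspn-keras` implementation; the helper names `normalize_log_accumulators` and `project_linear_accumulators` are hypothetical, as is the `eps` clipping value. It only illustrates that log accumulators can be normalized directly, while linear accumulators first need to be pushed back into $\mathbb R^+$.

```python
import math


def normalize_log_accumulators(log_acc):
    """Turn unconstrained log accumulators into log-weights that sum to 1."""
    # Numerically stable log-sum-exp, then subtract it (a log-softmax).
    m = max(log_acc)
    lse = m + math.log(sum(math.exp(a - m) for a in log_acc))
    return [a - lse for a in log_acc]


def project_linear_accumulators(acc, eps=1e-8):
    """Clip linear accumulators back into R+ before normalizing them."""
    clipped = [max(a, eps) for a in acc]
    total = sum(clipped)
    return [a / total for a in clipped]


# After a gradient step, any real value is still a valid log accumulator ...
log_acc = [0.3, -2.0, 1.5]
# ... but a negative linear accumulator is invalid and must be projected.
lin_acc = [0.3, -0.1, 1.5]

log_weights = normalize_log_accumulators(log_acc)
lin_weights = project_linear_accumulators(lin_acc)
```

In both cases the resulting weights are positive and sum to one; only the linear variant needs the extra clipping step.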

### Defining the Architecture
We begin by using non-overlapping convolution patches for our first two product layers,
since this way we make sure that no scopes overlap between products. 

For the
remainder of the product layers, we'll use *exponentially increasing dilation rates*. By doing so, we have
'overlapping' patches that still yield a valid SPN. These exponentially increasing dilation rates for convolutional SPNs were first
introduced in [this paper](https://arxiv.org/abs/1902.06155).

In [None]:
sum_product_network = keras.Sequential([
  spnk.layers.NormalizeStandardScore(input_shape=(28, 28, 1)),
  spnk.layers.NormalLeaf(
      num_components=16, 
      location_trainable=True,
      location_initializer=keras.initializers.TruncatedNormal(
          stddev=1.0, mean=0.0)
  ),
  # Non-overlapping products
  spnk.layers.Conv2DProduct(
      depthwise=True, 
      strides=[2, 2], 
      dilations=[1, 1], 
      kernel_size=[2, 2],
      padding='valid'
  ),
  spnk.layers.Local2DSum(num_sums=16),
  # Non-overlapping products
  spnk.layers.Conv2DProduct(
      depthwise=True, 
      strides=[2, 2], 
      dilations=[1, 1], 
      kernel_size=[2, 2],
      padding='valid'
  ),
  spnk.layers.Local2DSum(num_sums=32),
  # Overlapping products, starting at dilations [1, 1]
  spnk.layers.Conv2DProduct(
      depthwise=True, 
      strides=[1, 1], 
      dilations=[1, 1], 
      kernel_size=[2, 2],
      padding='full'
  ),
  spnk.layers.Local2DSum(num_sums=32),
  # Overlapping products, with dilations [2, 2] and full padding
  spnk.layers.Conv2DProduct(
      depthwise=True, 
      strides=[1, 1], 
      dilations=[2, 2], 
      kernel_size=[2, 2],
      padding='full'
  ),
  spnk.layers.Local2DSum(num_sums=64),
  # Overlapping products, with dilations [4, 4] and full padding
  spnk.layers.Conv2DProduct(
      depthwise=True, 
      strides=[1, 1], 
      dilations=[4, 4], 
      kernel_size=[2, 2],
      padding='full'
  ),
  spnk.layers.Local2DSum(num_sums=64),
  # Overlapping products, with dilations [8, 8] and 'final' padding to combine 
  # all scopes
  spnk.layers.Conv2DProduct(
      depthwise=True, 
      strides=[1, 1], 
      dilations=[8, 8], 
      kernel_size=[2, 2],
      padding='final'
  ),
  spnk.layers.SpatialToRegions(),
  # Class roots
  spnk.layers.DenseSum(num_sums=10),
  spnk.layers.RootSum(return_weighted_child_logits=True)
])

sum_product_network.summary()
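
The claim that exponentially increasing dilation rates keep overlapping patches valid can be checked with a small 1-D analogue in plain Python. The sketch below is not part of `libspn-keras`: `dilated_product` is a hypothetical helper that tracks which pixels each node depends on (its scope), assumes a kernel size of 2, models 'full' padding as empty scopes, and raises an error if a product would combine children with overlapping scopes.

```python
def dilated_product(scopes, dilation):
    """1-D product layer with kernel size 2, stride 1 and 'full' padding.

    Each output multiplies two inputs that are `dilation` positions apart.
    Padded positions are modeled as empty scopes.
    """
    n = len(scopes)
    empty = frozenset()
    out = []
    for j in range(n + dilation):
        left = scopes[j - dilation] if j - dilation >= 0 else empty
        right = scopes[j] if j < n else empty
        if left & right:
            raise ValueError(f"overlapping scopes at position {j}")
        out.append(left | right)
    return out


# 28 pixels after two non-overlapping (stride 2, kernel size 2) products:
# 7 cells that each cover 4 distinct pixels.
cells = [frozenset(range(4 * i, 4 * i + 4)) for i in range(7)]

layer = cells
for dilation in (1, 2, 4):  # doubling the dilation keeps children disjoint
    layer = dilated_product(layer, dilation)

widest_scope = max(len(scope) for scope in layer)

# Repeating dilation 1 instead of doubling it makes children overlap.
try:
    dilated_product(dilated_product(cells, 1), 1)
    repeated_dilation_is_valid = True
except ValueError:
    repeated_dilation_is_valid = False
```

In this analogue, the doubling schedule never multiplies overlapping scopes, whereas applying dilation 1 twice in a row is rejected.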

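### Setting up a `tf.data.Dataset` with `tensorflow_datasets`
Then, we'll configure a train set and a test set using `tensorflow_datasets`. As first suggested in the work by [Poon and Domingos](https://arxiv.org/abs/1202.3732), we whiten each sample by subtracting the mean and dividing by the
standard deviation. In this model, that step is handled by the `NormalizeStandardScore` layer at the start of the network, so the datasets themselves are left unnormalized.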

In [None]:
import tensorflow_datasets as tfds

batch_size = 32

mnist_train = (
    tfds.load(name="mnist", split="train", as_supervised=True)
    .shuffle(1024)
    .batch(batch_size)
)

mnist_test = (
    tfds.load(name="mnist", split="test", as_supervised=True)
    .batch(100)
)
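
For reference, the per-sample whitening described above can be sketched in a few lines of plain Python. This is only an illustration of the arithmetic, not the library's layer: the function name `standardize` and the small `eps` that guards against division by zero are assumptions.

```python
import math


def standardize(sample, eps=1e-8):
    """Subtract the sample's own mean and divide by its own standard deviation."""
    n = len(sample)
    mean = sum(sample) / n
    variance = sum((x - mean) ** 2 for x in sample) / n
    std = math.sqrt(variance)
    return [(x - mean) / (std + eps) for x in sample]


# A toy "image" with four pixel intensities.
pixels = [0.0, 0.0, 128.0, 255.0]
z = standardize(pixels)
```

After standardization, each sample has approximately zero mean and unit variance regardless of its original brightness or contrast.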

### Configuring the remaining training components
Now that we have an SPN that produces logits for each class at its root, we can attach a `SparseCategoricalCrossentropy` loss. Note that we need `from_logits=True` since the SPN computes its outputs in log-space!

In [None]:
optimizer = keras.optimizers.Adam(learning_rate=1e-2)
metrics = [keras.metrics.SparseCategoricalAccuracy()]
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

sum_product_network.compile(loss=loss, metrics=metrics, optimizer=optimizer)
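
To see what `from_logits=True` implies, here is a minimal plain-Python sketch of sparse categorical cross-entropy computed from log-space scores. It is a simplified stand-in for the Keras loss, not its implementation, and the values in `class_logits` are made up for illustration.

```python
import math


def log_softmax(logits):
    """Normalize log-space scores so that their exponentials sum to 1."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - lse for v in logits]


def sparse_categorical_crossentropy(logits, label):
    """Negative log-probability of the integer class `label`."""
    return -log_softmax(logits)[label]


# Hypothetical log-space root outputs, one per class (illustrative values).
class_logits = [-310.2, -305.7, -312.9]
loss = sparse_categorical_crossentropy(class_logits, label=1)

# Adding the same constant to every logit leaves the loss unchanged.
shifted_loss = sparse_categorical_crossentropy(
    [v + 100.0 for v in class_logits], label=1
)
```

Because the normalization happens inside the loss, the network can output raw log-space values of any magnitude without applying a softmax itself.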

### Training the SPN
We can simply use the `.fit` function that comes with Keras and pass our `tf.data.Dataset` to it to train!

In [None]:
sum_product_network.fit(mnist_train, epochs=5)
sum_product_network.evaluate(mnist_test)

### Storing the SPN
Finally, we might want to store our SPN. This is again a piece of cake using the `.save_weights` method of `tf.keras.models.Sequential`:

In [None]:
sum_product_network.save_weights('spn_weights.h5')