# Tessarine valued CNN with Real Batch Normalization

In [1]:
import numpy as np
import tensorflow as tf
import tensorflow.keras as kr
import matplotlib.pyplot as plt
from tensorflow.keras import datasets, models 
from tensorflow.keras import layers, activations, initializers, regularizers
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.engine.input_spec import InputSpec
from functools import partial
from keras.preprocessing.image import ImageDataGenerator
from keras.datasets import cifar10
from sklearn.metrics import accuracy_score
from keras.callbacks import LearningRateScheduler, ReduceLROnPlateau
from tensorflow.python.ops import variables as tf_variables

In [2]:
(Xtr_cifar, ytr_cifar), (Xte_cifar, yte_cifar) = cifar10.load_data()
Xtr_cifar = Xtr_cifar / 255
Xte_cifar = Xte_cifar / 255

Xtr_cifar -= np.mean(Xtr_cifar, axis=0)
Xte_cifar -= np.mean(Xtr_cifar, axis=0)

n_classes = 10
ytr_cifar = kr.utils.to_categorical(ytr_cifar, num_classes=n_classes)
yte_cifar = kr.utils.to_categorical(yte_cifar, num_classes=n_classes)

In [3]:
def learning_rate(epoch):
    lr = 1e-2
    if epoch < 151 and epoch > 9:
        lr *= 10.
    elif epoch > 199:
        lr /= 10.
    print('Learning rate: ', lr)
    return lr

In [4]:
def _compute_fans(shape):
    """Computes the number of input and output units for a weight shape.
    Args:
        shape: Integer shape tuple or TF tensor shape.
    Returns:
        A tuple of integer scalars (fan_in, fan_out).

    Extracted from tensorflow/keras/initializers. Available at
    https://github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/python/keras/initializers/initializers_v2.py
    """

    if len(shape) < 1:  # Just to avoid errors for constants.
        fan_in = fan_out = 1
    elif len(shape) == 1:
        fan_in = fan_out = shape[0]
    elif len(shape) == 2:
        fan_in = shape[0]
        fan_out = shape[1]
    else:
        # Assuming convolution kernels (2D, 3D, or more).
        # kernel shape: (..., input_depth, depth)
        receptive_field_size = 1
        for dim in shape[:-2]:
            receptive_field_size *= dim
        fan_in = shape[-2] * receptive_field_size
        fan_out = shape[-1] * receptive_field_size
    return int(fan_in), int(fan_out)


class Hypercomplex4DInitializer(initializers.Initializer):
    """
    Computes initialization based on quaternion variance.
    Options: he uniform, he normal, glorot uniform, glorot normal.
    References:
    [1] He, K., Zhang, X., Ren, S., and Sun, J. (2015b).  Delving deep into rectifiers: Surpassing human-level performance on imagenet classification.
    [2] Glorot, X. and Bengio, Y. (2010).  Understanding the difficulty of training deep feedforward neural networks.  
    In Teh, Y. W. and Titterington, M., editors, Proceedings of the Thirteenth International Conference on Artificial Intelligence and Statistics,
    volume 9 of Proceedings of Machine Learning Research, pages 249–256, Chia Laguna Resort, Sardinia, Italy. PMLR.
    """
    def __init__(self, criterion='he', distribution='uniform', seed=31337):
        self.criterion = criterion
        self.distribution = distribution
        self.seed = seed

    def __call__(self, shape, dtype):
        fan_in, fan_out = _compute_fans(shape)

        if self.criterion == 'he':
            std = 1. / np.sqrt(2 * fan_in)
        elif self.criterion == 'glorot':
            std = 1. / np.sqrt(2 * (fan_in + fan_out))
        else:
            raise ValueError("Chosen criterion was not identified.")

        if self.distribution == 'normal':
            return tf.random.normal(shape, mean=0, stddev=std, dtype=dtype, seed=self.seed)
        elif self.distribution == 'uniform':
            lim = std * np.sqrt(3)
            return tf.random.uniform(shape, minval=-lim, maxval=lim, dtype=dtype, seed=self.seed)
        else:
            raise ValueError("Chosen distribution was not identified")

In [5]:
class TessConv2D(layers.Layer):
    """
    Tessarine valued 2D convolution layer.
    References:
    [1] Trabelsi, C., Bilaniuk, O., Serdyuk, D., Subramanian, S., Santos, J. F., Mehri, S., Rostamzadeh, N., Bengio, Y., and Pal, C. J. (2017). Deep complex networks.
    [2] Gaudet, C. and Maida, A. (2017). Deep quaternion networks.
    """

    def __init__(self, 
                 filters, 
                 kernel_size, 
                 strides=1, 
                 padding='SAME',
                 use_bias=False,
                 activation=None,
                 initializer=Hypercomplex4DInitializer,
                 data_format=None,
                 kernel_regularizer=1e-3):
        super(TessConv2D, self).__init__()
        self.filters = filters 
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.use_bias = use_bias
        self.activation = activations.get(activation)
        self.initializer = initializer
        self.data_format = data_format
        if kernel_regularizer is not None:
            self.kernel_regularizer = kr.regularizers.l2(kernel_regularizer)
        else:
            self.kernel_regularizer = kernel_regularizer

    def _get_channel_axis(self):
        if self.data_format == 'channels_first':
            raise ValueError('TessConv2d is designed only for channels_last. '
                             'The input has been changed to channels last!')
        else:
            return -1
    
    def _get_input_channel(self, input_shape):
        channel_axis = self._get_channel_axis()
        if input_shape.dims[channel_axis].value is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        return int(input_shape[channel_axis])

    def build(self, input_shape):
        input_shape = tensor_shape.TensorShape(input_shape)
        input_channel = self._get_input_channel(input_shape)
        if input_channel % 4 != 0:
            raise ValueError('The number of input channels must be divisible by 4.')
    
        input_dim = input_channel // 4
        kernel_shape = self.kernel_size + (input_dim, self.filters)
        
        self.f0 = self.add_weight(
            name='real_kernel',
            shape=kernel_shape,
            initializer=self.initializer,
            trainable=True,
            regularizer=self.kernel_regularizer
        )
        self.f1 = self.add_weight(
            name='imag_i_kernel',
            shape=kernel_shape,
            initializer=self.initializer,
            trainable=True,
            regularizer=self.kernel_regularizer
        )
        self.f2 = self.add_weight(
            name='imag_j_kernel',
            shape=kernel_shape,
            initializer=self.initializer,
            trainable=True,
            regularizer=self.kernel_regularizer
        )
        self.f3 = self.add_weight(
            name='imag_k_kernel',
            shape=kernel_shape,
            initializer=self.initializer,
            trainable=True,
            regularizer=self.kernel_regularizer
        )
        
        if self.use_bias:
            self.bias = self.add_weight(
                name='bias',
                shape=(4*self.filters,),
                initializer="zeros",
                trainable=True,
                dtype=self.dtype)
        else:
            self.bias = None

    def call(self, inputs):
        F_r = tf.concat([ self.f0,-self.f1, self.f2,-self.f3],axis=2)
        F_i = tf.concat([ self.f1, self.f0, self.f3, self.f2],axis=2)
        F_j = tf.concat([ self.f2,-self.f3, self.f0,-self.f1],axis=2)
        F_k = tf.concat([ self.f3, self.f2, self.f1, self.f0],axis=2)
               
        y_r = tf.nn.conv2d(inputs, F_r, strides=self.strides, padding=self.padding)
        y_i = tf.nn.conv2d(inputs, F_i, strides=self.strides, padding=self.padding)
        y_j = tf.nn.conv2d(inputs, F_j, strides=self.strides, padding=self.padding)
        y_k = tf.nn.conv2d(inputs, F_k, strides=self.strides, padding=self.padding)
        
        outputs = tf.concat([y_r, y_i, y_j, y_k],axis=3)
        
        if self.use_bias:
            outputs = tf.nn.bias_add(outputs,self.bias)
            
        if self.activation is not None:
            outputs = self.activation(outputs)
        
        return outputs

In [6]:
DefaultConvTess = partial(TessConv2D, kernel_size=(3,3), strides=1, padding="SAME", kernel_regularizer=1e-3, use_bias=False)

In [7]:
class TessResidualUnit(layers.Layer):
    """ 
    Tessarine valued residual unit.
    References:
    [1] He, K., Zhang, X., Ren, S., and Sun, J. (2015).  Deep residual learning for image recog-nition.
    [2] He, K., Zhang, X., Ren, S., and Sun, J. (2016).  Identity mappings in deep residual net-works.
    """
    def __init__(self, filters, strides=1, activation="elu", **kwargs):
        super().__init__(**kwargs)
        self.activation = kr.activations.get(activation)
        self.main_layers = [
            DefaultConvTess(filters, strides=strides),
            kr.layers.BatchNormalization(),
            self.activation,
            DefaultConvTess(filters),
            kr.layers.BatchNormalization(),
            self.activation]
        self.skip_layers = []
        if strides > 1:
            self.skip_layers = [
                DefaultConvTess(filters, kernel_size=(1,1), strides=strides),
                kr.layers.BatchNormalization()]

    def call(self, inputs):
        Z = inputs
        for layer in self.main_layers:
            Z = layer(Z)
        skip_Z = inputs
        for layer in self.skip_layers:
            skip_Z = layer(skip_Z)
        return self.activation(Z + skip_Z)

In [8]:
activation = "elu"

tess_resnet = kr.models.Sequential()
tess_resnet.add(DefaultConvTess(6, kernel_size=(3,3), strides=1))
prev_filters = 6
for filters in [6] * 3 + [12] * 2 + [24] * 2:
    strides = 1 if filters == prev_filters else 2
    tess_resnet.add(TessResidualUnit(filters, strides=strides, activation=activation))
    prev_filters = filters
tess_resnet.add(kr.layers.GlobalAvgPool2D())
tess_resnet.add(kr.layers.Flatten())
tess_resnet.add(kr.layers.Dense(10, activation="softmax"))
tess_resnet.compile(loss="categorical_crossentropy", optimizer=kr.optimizers.SGD(learning_rate=learning_rate(0)), metrics=["accuracy"])

Learning rate:  0.01


In [9]:
lr_scheduler = LearningRateScheduler(learning_rate)

lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.1),
                               cooldown=0,
                               patience=5,
                               min_lr=1e-4)
callbacks = [lr_reducer,lr_scheduler]

In [10]:
Xtr_rgb = np.append(np.zeros([Xtr_cifar.shape[0],32,32,1]),Xtr_cifar,axis=3)
Xte_rgb = np.append(np.zeros([Xte_cifar.shape[0],32,32,1]),Xte_cifar,axis=3)

In [11]:
datagen = ImageDataGenerator(width_shift_range=0.125, height_shift_range=0.125, horizontal_flip=True)
datagen.fit(Xtr_rgb)

In [12]:
batch_size = 128
epochs = 250

tess_resnet.fit(datagen.flow(Xtr_rgb, ytr_cifar, batch_size=batch_size), epochs=epochs, validation_data=(Xte_rgb,yte_cifar), callbacks=callbacks)

Epoch 1/250
Learning rate:  0.01
Epoch 2/250
Learning rate:  0.01
Epoch 3/250
Learning rate:  0.01
Epoch 4/250
Learning rate:  0.01
Epoch 5/250
Learning rate:  0.01
Epoch 6/250
Learning rate:  0.01
Epoch 7/250
Learning rate:  0.01
Epoch 8/250
Learning rate:  0.01
Epoch 9/250
Learning rate:  0.01
Epoch 10/250
Learning rate:  0.01
Epoch 11/250
Learning rate:  0.1
Epoch 12/250
Learning rate:  0.1
Epoch 13/250
Learning rate:  0.1
Epoch 14/250
Learning rate:  0.1
Epoch 15/250
Learning rate:  0.1
Epoch 16/250
Learning rate:  0.1
Epoch 17/250
Learning rate:  0.1
Epoch 18/250
Learning rate:  0.1
Epoch 19/250
Learning rate:  0.1
Epoch 20/250
Learning rate:  0.1
Epoch 21/250
Learning rate:  0.1
Epoch 22/250
Learning rate:  0.1
Epoch 23/250
Learning rate:  0.1
Epoch 24/250
Learning rate:  0.1
Epoch 25/250
Learning rate:  0.1
Epoch 26/250
Learning rate:  0.1
Epoch 27/250
Learning rate:  0.1
Epoch 28/250
Learning rate:  0.1
Epoch 29/250
Learning rate:  0.1
Epoch 30/250
Learning rate:  0.1
Epoch 31/

<tensorflow.python.keras.callbacks.History at 0x7fc0801e3690>

In [13]:
tess_resnet.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
tess_conv2d (TessConv2D)     (None, None, None, 24)    216       
_________________________________________________________________
tess_residual_unit (TessResi (None, None, None, 24)    2784      
_________________________________________________________________
tess_residual_unit_1 (TessRe (None, None, None, 24)    2784      
_________________________________________________________________
tess_residual_unit_2 (TessRe (None, None, None, 24)    2784      
_________________________________________________________________
tess_residual_unit_3 (TessRe (None, None, None, 48)    8640      
_________________________________________________________________
tess_residual_unit_4 (TessRe (None, None, None, 48)    10752     
_________________________________________________________________
tess_residual_unit_5 (TessRe (None, None, None, 96)    3

In [14]:
accuracy_score(np.argmax(ytr_cifar, axis=1),np.argmax(tess_resnet.predict(Xtr_rgb),axis=1))

0.9646

In [15]:
accuracy_score(np.argmax(yte_cifar, axis=1),np.argmax(tess_resnet.predict(Xte_rgb),axis=1))

0.8463