# Quaternion Valued CNN with Hypercomplex Batch Normalization

In [1]:
import numpy as np
import tensorflow as tf
import tensorflow.keras as kr
import matplotlib.pyplot as plt
from tensorflow.keras import datasets, models 
from tensorflow.keras import layers, activations, initializers, regularizers
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.engine.input_spec import InputSpec
from functools import partial
from keras.preprocessing.image import ImageDataGenerator
from keras.datasets import cifar10
from sklearn.metrics import accuracy_score
from keras.callbacks import LearningRateScheduler, ReduceLROnPlateau
from tensorflow.python.ops import variables as tf_variables

In [2]:
(Xtr_cifar, ytr_cifar), (Xte_cifar, yte_cifar) = cifar10.load_data()
Xtr_cifar = Xtr_cifar / 255
Xte_cifar = Xte_cifar / 255

Xtr_cifar -= np.mean(Xtr_cifar, axis=0)
Xte_cifar -= np.mean(Xtr_cifar, axis=0)

n_classes = 10
ytr_cifar = kr.utils.to_categorical(ytr_cifar, num_classes=n_classes)
yte_cifar = kr.utils.to_categorical(yte_cifar, num_classes=n_classes)

In [3]:
def learning_rate(epoch):
    lr = 1e-2
    if epoch < 151 and epoch > 9:
        lr *= 10.
    elif epoch > 199:
        lr /= 10.
    print('Learning rate: ', lr)
    return lr

In [4]:
def _compute_fans(shape):
    """Computes the number of input and output units for a weight shape.
    Args:
        shape: Integer shape tuple or TF tensor shape.
    Returns:
        A tuple of integer scalars (fan_in, fan_out).

    Extracted from tensorflow/keras/initializers. Available at
    https://github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/python/keras/initializers/initializers_v2.py
    """

    if len(shape) < 1:  # Just to avoid errors for constants.
        fan_in = fan_out = 1
    elif len(shape) == 1:
        fan_in = fan_out = shape[0]
    elif len(shape) == 2:
        fan_in = shape[0]
        fan_out = shape[1]
    else:
        # Assuming convolution kernels (2D, 3D, or more).
        # kernel shape: (..., input_depth, depth)
        receptive_field_size = 1
        for dim in shape[:-2]:
            receptive_field_size *= dim
        fan_in = shape[-2] * receptive_field_size
        fan_out = shape[-1] * receptive_field_size
    return int(fan_in), int(fan_out)


class Hypercomplex4DInitializer(initializers.Initializer):
    """
    Computes initialization based on quaternion variance.
    Options: he uniform, he normal, glorot uniform, glorot normal.
    References:
    [1] He, K., Zhang, X., Ren, S., and Sun, J. (2015b).  Delving deep into rectifiers: Surpassing human-level performance on imagenet classification.
    [2] Glorot, X. and Bengio, Y. (2010).  Understanding the difficulty of training deep feedforward neural networks.  
    In Teh, Y. W. and Titterington, M., editors, Proceedings of the Thirteenth International Conference on Artificial Intelligence and Statistics,
    volume 9 of Proceedings of Machine Learning Research, pages 249–256, Chia Laguna Resort, Sardinia, Italy. PMLR.
    """
    def __init__(self, criterion='he', distribution='uniform', seed=31337):
        self.criterion = criterion
        self.distribution = distribution
        self.seed = seed

    def __call__(self, shape, dtype):
        fan_in, fan_out = _compute_fans(shape)

        if self.criterion == 'he':
            std = 1. / np.sqrt(2 * fan_in)
        elif self.criterion == 'glorot':
            std = 1. / np.sqrt(2 * (fan_in + fan_out))
        else:
            raise ValueError("Chosen criterion was not identified.")

        if self.distribution == 'normal':
            return tf.random.normal(shape, mean=0, stddev=std, dtype=dtype, seed=self.seed)
        elif self.distribution == 'uniform':
            lim = std * np.sqrt(3)
            return tf.random.uniform(shape, minval=-lim, maxval=lim, dtype=dtype, seed=self.seed)
        else:
            raise ValueError("Chosen distribution was not identified")

In [5]:
class QuatConv2D(layers.Layer):
    """
    Quaternion valued 2D convolution layer.
    References:
    [1] Trabelsi, C., Bilaniuk, O., Serdyuk, D., Subramanian, S., Santos, J. F., Mehri, S., Rostamzadeh, N., Bengio, Y., and Pal, C. J. (2017). Deep complex networks.
    [2] Gaudet, C. and Maida, A. (2017). Deep quaternion networks.
    """
    def __init__(self, 
                 filters, 
                 kernel_size, 
                 strides=1, 
                 padding='SAME',
                 use_bias=False,
                 activation=None,
                 initializer=Hypercomplex4DInitializer(),
                 data_format=None,
                 kernel_regularizer=1e-4):
        super(QuatConv2D, self).__init__()
        self.filters = filters 
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.use_bias = use_bias
        self.activation = activations.get(activation)
        self.initializer = initializer
        self.data_format = data_format
        self.kernel_regularizer = kernel_regularizer
        
    def _get_channel_axis(self):
        if self.data_format == 'channels_first':
            raise ValueError('QuatConv2d is designed only for channels_last. '
                             'The input has been changed to channels last!')
        else:
            return -1
    
    def _get_input_channel(self, input_shape):
        channel_axis = self._get_channel_axis()
        if input_shape.dims[channel_axis].value is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        return int(input_shape[channel_axis])

    def build(self, input_shape):
        input_shape = tensor_shape.TensorShape(input_shape)
        input_channel = self._get_input_channel(input_shape)
        if input_channel % 4 != 0:
            raise ValueError('The number of input channels must be divisible by 4.')
    
        input_dim = input_channel // 4
        kernel_shape = self.kernel_size + (input_dim, self.filters)
        
        self.f0 = self.add_weight(
            name='real_kernel',
            shape=kernel_shape,
            initializer=self.initializer,
            trainable=True,
            regularizer=kr.regularizers.l2(self.kernel_regularizer)
        )
        self.f1 = self.add_weight(
            name='imag_i_kernel',
            shape=kernel_shape,
            initializer=self.initializer,
            trainable=True,
            regularizer=kr.regularizers.l2(self.kernel_regularizer)
        )
        self.f2 = self.add_weight(
            name='imag_j_kernel',
            shape=kernel_shape,
            initializer=self.initializer,
            trainable=True,
            regularizer=kr.regularizers.l2(self.kernel_regularizer)
        )
        self.f3 = self.add_weight(
            name='imag_k_kernel',
            shape=kernel_shape,
            initializer=self.initializer,
            trainable=True,
            regularizer=kr.regularizers.l2(self.kernel_regularizer)
        )
        
        if self.use_bias:
            self.bias = self.add_weight(
                name='bias',
                shape=(4*self.filters,),
                initializer="zeros",
                trainable=True,
                dtype=self.dtype)
        else:
            self.bias = None

    def call(self, inputs):
        # Filter multiplied from the right!
        F_r = tf.concat([ self.f0,-self.f1,-self.f2,-self.f3],axis=2)
        F_i = tf.concat([ self.f1, self.f0, self.f3,-self.f2],axis=2)
        F_j = tf.concat([ self.f2,-self.f3, self.f0, self.f1],axis=2)
        F_k = tf.concat([ self.f3, self.f2,-self.f1, self.f0],axis=2)
               
        y_r = tf.nn.conv2d(inputs, F_r, strides=self.strides, padding=self.padding)
        y_i = tf.nn.conv2d(inputs, F_i, strides=self.strides, padding=self.padding)
        y_j = tf.nn.conv2d(inputs, F_j, strides=self.strides, padding=self.padding)
        y_k = tf.nn.conv2d(inputs, F_k, strides=self.strides, padding=self.padding)
        
        outputs = tf.concat([y_r, y_i, y_j, y_k],axis=3)
        
        if self.use_bias:
            outputs = tf.nn.bias_add(outputs,self.bias)
            
        if self.activation is not None:
            outputs = self.activation(outputs)
        
        return outputs

In [6]:
def diag_init(shape, dtype=None):
    return tf.ones(shape) / 2.

class Hypercomplex4DBNActivation(layers.Layer):
    """
    Batch Normalization for tessarines and quaternions.
    Based on matrix whitening. Decorrelates each component of tessarine/quaternion.
    Includes activation: can be placed before, after or in the middle of BN.
    References:
    [1] Ioffe, S. and Szegedy, C. (2015). Batch normalization: Accelerating deep network training by reducing internal covariate shift.
    [2] Kessy, A., Lewin, A., and Strimmer, K. (2018). Optimal whitening and decorrelation. The American Statistician, 72(4):309–314.
    [3] Trabelsi, C., Bilaniuk, O., Serdyuk, D., Subramanian, S., Santos, J. F., Mehri, S., Ros-tamzadeh, N., Bengio, Y., and Pal, C. J. (2017). Deep complex networks.
    [4] Gaudet, C. and Maida, A. (2017). Deep quaternion networks.
    """
    def __init__(self,
                 center=True,
                 scale=True,
                 momentum=0.9,
                 beta_init='zeros',
                 gam_diag_init='diag_init',
                 gam_off_init='zeros',
                 mov_mean_init='zeros',
                 mov_var_init='diag_init',
                 mov_cov_init='zeros',
                 beta_reg=None,
                 gam_diag_reg=None,
                 gam_off_reg=None,
                 activation="elu",
                 activation_position="after",
                 epsilon=1e-6,
                 **kwargs):
        super(Hypercomplex4DBNActivation, self).__init__(**kwargs)
        self.center = center
        self.scale = scale
        self.momentum = momentum
        self.beta_init = initializers.get(beta_init)

        if gam_diag_init == 'diag_init':
            self.gam_diag_init = diag_init
        else:
            self.gam_diag_init = initializers.get(gam_diag_init)

        self.gam_off_init = initializers.get(gam_off_init)
        self.mov_mean_init = initializers.get(mov_mean_init)

        if mov_var_init == 'diag_init':
            self.mov_var_init = diag_init
        else:    
            self.mov_var_init = initializers.get(mov_var_init)
            
        self.mov_cov_init = initializers.get(mov_cov_init)
        self.beta_reg = regularizers.get(beta_reg)
        self.gam_diag_reg = regularizers.get(gam_diag_reg)
        self.gam_off_reg = regularizers.get(gam_off_reg)
        self.activation = activations.get(activation)
        self.activation_position = activation_position
        self.epsilon = epsilon

    def build(self, input_shape):
        input_dim = input_shape[-1] // 4
        vars_shape = [input_dim, 1]
        gamma_shape = (input_dim,)

        if self.scale:
            self.mov_Vrr = self.add_weight(shape=vars_shape,
                                           initializer=self.mov_var_init,
                                           trainable=False,
                                           name="mov_Vrr")
            self.mov_Vri = self.add_weight(shape=vars_shape,
                                           initializer=self.mov_cov_init,
                                           trainable=False,
                                           name="mov_Vri")
            self.mov_Vrj = self.add_weight(shape=vars_shape,
                                           initializer=self.mov_cov_init,
                                           trainable=False,
                                           name="mov_Vrj")
            self.mov_Vrk = self.add_weight(shape=vars_shape,
                                           initializer=self.mov_cov_init,
                                           trainable=False,
                                           name="mov_Vrk")
            self.mov_Vii = self.add_weight(shape=vars_shape,
                                           initializer=self.mov_var_init,
                                           trainable=False,
                                           name="mov_Vii")
            self.mov_Vij = self.add_weight(shape=vars_shape,
                                           initializer=self.mov_cov_init,
                                           trainable=False,
                                           name="mov_Vij")
            self.mov_Vik = self.add_weight(shape=vars_shape,
                                           initializer=self.mov_cov_init,
                                           trainable=False,
                                           name="mov_Vik")
            self.mov_Vjj = self.add_weight(shape=vars_shape,
                                           initializer=self.mov_var_init,
                                           trainable=False,
                                           name="mov_Vjj")
            self.mov_Vjk = self.add_weight(shape=vars_shape,
                                           initializer=self.mov_cov_init,
                                           trainable=False,
                                           name="mov_Vjk")
            self.mov_Vkk = self.add_weight(shape=vars_shape,
                                           initializer=self.mov_var_init,
                                           trainable=False,
                                           name="mov_Vkk")

            self.gam_rr = self.add_weight(shape=gamma_shape,
                                          initializer=self.gam_diag_init,
                                          regularizer=self.gam_diag_reg,
                                          name="gam_rr")
            self.gam_ri = self.add_weight(shape=gamma_shape,
                                          initializer=self.gam_off_init,
                                          regularizer=self.gam_off_reg,
                                          name="gam_ri")
            self.gam_rj = self.add_weight(shape=gamma_shape,
                                          initializer=self.gam_off_init,
                                          regularizer=self.gam_off_reg,
                                          name="gam_rj")
            self.gam_rk = self.add_weight(shape=gamma_shape,
                                          initializer=self.gam_off_init,
                                          regularizer=self.gam_off_reg,
                                          name="gam_rk")
            self.gam_ii = self.add_weight(shape=gamma_shape,
                                          initializer=self.gam_diag_init,
                                          regularizer=self.gam_diag_reg,
                                          name="gam_ii")
            self.gam_ij = self.add_weight(shape=gamma_shape,
                                          initializer=self.gam_off_init,
                                          regularizer=self.gam_off_reg,
                                          name="gam_ij")
            self.gam_ik = self.add_weight(shape=gamma_shape,
                                          initializer=self.gam_off_init,
                                          regularizer=self.gam_off_reg,
                                          name="gam_ik")
            self.gam_jj = self.add_weight(shape=gamma_shape,
                                          initializer=self.gam_diag_init,
                                          regularizer=self.gam_diag_reg,
                                          name="gam_jj")
            self.gam_jk = self.add_weight(shape=gamma_shape,
                                          initializer=self.gam_off_init,
                                          regularizer=self.gam_off_reg,
                                          name="gam_jk")
            self.gam_kk = self.add_weight(shape=gamma_shape,
                                          initializer=self.gam_diag_init,
                                          regularizer=self.gam_diag_reg,
                                          name="gam_kk")
        else:
            self.mov_Vrr = None
            self.mov_Vri = None
            self.mov_Vrj = None
            self.mov_Vrk = None
            self.mov_Vii = None
            self.mov_Vij = None
            self.mov_Vik = None
            self.mov_Vjj = None
            self.mov_Vjk = None
            self.mov_Vkk = None
            self.gam_rr = None
            self.gam_ri = None
            self.gam_rj = None
            self.gam_rk = None
            self.gam_ii = None
            self.gam_ij = None
            self.gam_ik = None
            self.gam_jj = None
            self.gam_jk = None
            self.gam_kk = None

        if self.center:
            self.beta = self.add_weight(shape=(1, 1, 1, input_shape[-1]),
                                        initializer=self.beta_init,
                                        regularizer=self.beta_reg,
                                        name="beta")
            self.mov_mean = self.add_weight(shape=(1, 1, 1, input_shape[-1]),
                                            initializer=self.mov_mean_init,
                                            trainable=False,
                                            name="mov_mean")
        else:
            self.beta = None
            self.mov_mean = None

    def _compute_variances(self, centered_r, centered_i, centered_j, centered_k, input_dim):
        Vrr = kr.backend.mean(
            centered_r ** 2,
            axis=[0, 1, 2]
        ) + self.epsilon

        Vri = kr.backend.mean(
            centered_r * centered_i,
            axis=[0, 1, 2]
        )

        Vrj = kr.backend.mean(
            centered_r * centered_j,
            axis=[0, 1, 2]
        )

        Vrk = kr.backend.mean(
            centered_r * centered_k,
            axis=[0, 1, 2]
        )

        Vii = kr.backend.mean(
            centered_i ** 2,
            axis=[0, 1, 2]
        ) + self.epsilon

        Vij = kr.backend.mean(
            centered_i * centered_j,
            axis=[0, 1, 2]
        )

        Vik = kr.backend.mean(
            centered_i * centered_k,
            axis=[0, 1, 2]
        )

        Vjj = kr.backend.mean(
            centered_j ** 2,
            axis=[0, 1, 2]
        ) + self.epsilon

        Vjk = kr.backend.mean(
            centered_j * centered_k,
            axis=[0, 1, 2]
        )

        Vkk = kr.backend.mean(
            centered_k ** 2,
            axis=[0, 1, 2]
        ) + self.epsilon

        pars_shape = [input_dim, 1]
        Vrr = tf.reshape(Vrr, pars_shape)
        Vri = tf.reshape(Vri, pars_shape)
        Vrj = tf.reshape(Vrj, pars_shape)
        Vrk = tf.reshape(Vrk, pars_shape)
        Vii = tf.reshape(Vii, pars_shape)
        Vij = tf.reshape(Vij, pars_shape)
        Vik = tf.reshape(Vik, pars_shape)
        Vjj = tf.reshape(Vjj, pars_shape)
        Vjk = tf.reshape(Vjk, pars_shape)
        Vkk = tf.reshape(Vkk, pars_shape)

        return Vrr, Vri, Vrj, Vrk, Vii, Vij, Vik, Vjj, Vjk, Vkk
    
    def _moving_exponential_update(self, var, value):
        decay = 1 - self.momentum
        var.assign_sub(var * decay)
        var.assign_add(value * decay)

    def _update_moving_parameters(self, mean, Vrr, Vri, Vrj, Vrk, Vii, Vij, Vik, Vjj, Vjk, Vkk):
        if self.center:
            self._moving_exponential_update(self.mov_mean, mean)

        if self.scale:
            self._moving_exponential_update(self.mov_Vrr, Vrr)
            self._moving_exponential_update(self.mov_Vri, Vri)
            self._moving_exponential_update(self.mov_Vrj, Vrj)
            self._moving_exponential_update(self.mov_Vrk, Vrk)
            self._moving_exponential_update(self.mov_Vii, Vii)
            self._moving_exponential_update(self.mov_Vij, Vij)
            self._moving_exponential_update(self.mov_Vik, Vik)
            self._moving_exponential_update(self.mov_Vjj, Vjj)
            self._moving_exponential_update(self.mov_Vjk, Vjk)
            self._moving_exponential_update(self.mov_Vkk, Vkk)

    def call(self, inputs, training=None):
        if (not self.center) or (not self.scale):
            raise ValueError("Batch Normalization should scale or center.")

        input_shape = kr.backend.int_shape(inputs)
        input_dim = input_shape[-1] // 4

        # Activation before
        if self.activation_position == "before":
            output = self.activation(inputs)
        else:
            output = inputs

        if training in {0, False}:
            mean = self.mov_mean
            centered = output - mean

            if self.scale:
                centered_r = centered[:, :, :, :input_dim]
                centered_i = centered[:, :, :, input_dim:input_dim * 2]
                centered_j = centered[:, :, :, input_dim * 2:input_dim * 3]
                centered_k = centered[:, :, :, input_dim * 3:]

                Vrr = self.mov_Vrr
                Vri = self.mov_Vri
                Vrj = self.mov_Vrj
                Vrk = self.mov_Vrk
                Vii = self.mov_Vii
                Vij = self.mov_Vij
                Vik = self.mov_Vik
                Vjj = self.mov_Vjj
                Vjk = self.mov_Vjk
                Vkk = self.mov_Vkk
        else:
            # mean and centering
            mean = kr.backend.mean(output, axis=[0, 1, 2])
            mean = kr.backend.reshape(mean, [1, 1, 1, input_dim * 4])
            centered = output - mean

            if self.scale:
                centered_r = centered[:, :, :, :input_dim]
                centered_i = centered[:, :, :, input_dim:input_dim * 2]
                centered_j = centered[:, :, :, input_dim * 2:input_dim * 3]
                centered_k = centered[:, :, :, input_dim * 3:]

                Vrr, Vri, Vrj, Vrk, Vii, Vij, Vik, Vjj, Vjk, Vkk = self._compute_variances(centered_r, centered_i, centered_j, centered_k, input_dim)
            else:
                Vrr, Vri, Vrj, Vrk, Vii, Vij, Vik, Vjj, Vjk, Vkk = [None for i in range(10)]  

            self._update_moving_parameters(mean, Vrr, Vri, Vrj, Vrk, Vii, Vij, Vik, Vjj, Vjk, Vkk)

        if self.scale:
            var_reshape = [input_dim, 1, 4]
            # covariance matrix
            V = tf.concat([[tf.reshape(tf.concat([Vrr, Vri, Vrj, Vrk], axis=1), var_reshape)],
                           [tf.reshape(tf.concat([Vri, Vii, Vij, Vik], axis=1), var_reshape)],
                           [tf.reshape(tf.concat([Vrj, Vij, Vjj, Vjk], axis=1), var_reshape)],
                           [tf.reshape(tf.concat([Vrk, Vik, Vjk, Vkk], axis=1), var_reshape)]], axis=2)

            # Whitening
            R = tf.reshape(tf.linalg.cholesky(V), [input_dim, 4, 4])
            W = tf.linalg.inv(tf.transpose(R, perm=[0,2,1]))

            Wrr = W[:,0,0]
            Wri = W[:,0,1]
            Wrj = W[:,0,2]
            Wrk = W[:,0,3]
            Wii = W[:,1,1]
            Wij = W[:,1,2]
            Wik = W[:,1,3]
            Wjj = W[:,2,2]
            Wjk = W[:,2,3]
            Wkk = W[:,3,3]

            output_r = centered_r * Wrr
            output_i = centered_r * Wri + centered_i * Wii
            output_j = centered_r * Wrj + centered_i * Wij + centered_j * Wjj
            output_k = centered_r * Wrk + centered_i * Wik + centered_j * Wjk + centered_k * Wkk

            if self.activation_position == "middle":
                output_r = self.activation(output_r)
                output_i = self.activation(output_i)
                output_j = self.activation(output_j)
                output_k = self.activation(output_k)

            out_r = output_r * self.gam_rr
            out_i = output_r * self.gam_ri + output_i * self.gam_ii
            out_j = output_r * self.gam_rj + output_i * self.gam_ij + output_j * self.gam_jj
            out_k = output_r * self.gam_rk + output_i * self.gam_ik + output_j * self.gam_jk + output_k * self.gam_kk
            output = tf.concat([out_r, out_i, out_j, out_k], axis=-1)
        else:
            output = centered
            if self.activation_position == "middle":
                output = self.activation(output)

        if self.center:
            output = output + self.beta

        if self.activation_position == "after":
            output = self.activation(output)

        return output


In [7]:
DefaultConvQuat = partial(QuatConv2D, kernel_size=(3,3), strides=1, padding="SAME", kernel_regularizer=1e-3, use_bias=False)

In [8]:
class QuatResidualUnit(layers.Layer):
    """ 
    Quaternion valued residual unit.
    References:
    [1] He, K., Zhang, X., Ren, S., and Sun, J. (2015).  Deep residual learning for image recognition.
    [2] He, K., Zhang, X., Ren, S., and Sun, J. (2016).  Identity mappings in deep residual networks.
    """
    def __init__(self, filters, strides=1, activation="elu", activation_position="after", **kwargs):
        super().__init__(**kwargs)
        self.activation = kr.activations.get(activation)
        self.main_layers = [
            DefaultConvQuat(filters, strides=strides),
            Hypercomplex4DBNActivation(activation=activation, activation_position=activation_position), 
            DefaultConvQuat(filters),
            Hypercomplex4DBNActivation(activation=activation, activation_position=activation_position),]
        self.skip_layers = []
        if strides > 1:
            self.skip_layers = [
                DefaultConvQuat(filters, kernel_size=(1,1), strides=strides),
                Hypercomplex4DBNActivation(activation=None, activation_position='no_activation')]

    def call(self, inputs):
        Z = inputs
        for layer in self.main_layers:
            Z = layer(Z)
        skip_Z = inputs
        for layer in self.skip_layers:
            skip_Z = layer(skip_Z)
        return self.activation(Z + skip_Z)

In [9]:
activation = "elu"
activation_position = 'after'

quat_resnet = kr.models.Sequential()
quat_resnet.add(DefaultConvQuat(6, kernel_size=(3,3), strides=1))
prev_filters = 6
for filters in [6] * 3 + [12] * 2 + [24] * 2:
    strides = 1 if filters == prev_filters else 2
    quat_resnet.add(QuatResidualUnit(filters, strides=strides, activation=activation, activation_position=activation_position))
    prev_filters = filters
quat_resnet.add(kr.layers.GlobalAvgPool2D())
quat_resnet.add(kr.layers.Flatten())
quat_resnet.add(kr.layers.Dense(10, activation="softmax"))
quat_resnet.compile(loss="categorical_crossentropy", optimizer=kr.optimizers.SGD(learning_rate=learning_rate(0)), metrics=["accuracy"])

Learning rate:  0.01


In [10]:
lr_scheduler = LearningRateScheduler(learning_rate)

lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.1),
                               cooldown=0,
                               patience=5,
                               min_lr=1e-4)
callbacks = [lr_reducer,lr_scheduler]

In [11]:
Xtr_rgb = np.append(np.zeros([Xtr_cifar.shape[0],32,32,1]),Xtr_cifar,axis=3)
Xte_rgb = np.append(np.zeros([Xte_cifar.shape[0],32,32,1]),Xte_cifar,axis=3)

In [12]:
datagen = ImageDataGenerator(width_shift_range=0.125, height_shift_range=0.125, horizontal_flip=True)
datagen.fit(Xtr_rgb)

In [13]:
batch_size = 128
epochs = 250

quat_resnet.fit(datagen.flow(Xtr_rgb, ytr_cifar, batch_size=batch_size), epochs=epochs, validation_data=(Xte_rgb,yte_cifar), callbacks=callbacks)

Epoch 1/250
Learning rate:  0.01
Epoch 2/250
Learning rate:  0.01
Epoch 3/250
Learning rate:  0.01
Epoch 4/250
Learning rate:  0.01
Epoch 5/250
Learning rate:  0.01
Epoch 6/250
Learning rate:  0.01
Epoch 7/250
Learning rate:  0.01
Epoch 8/250
Learning rate:  0.01
Epoch 9/250
Learning rate:  0.01
Epoch 10/250
Learning rate:  0.01
Epoch 11/250
Learning rate:  0.1
Epoch 12/250
Learning rate:  0.1
Epoch 13/250
Learning rate:  0.1
Epoch 14/250
Learning rate:  0.1
Epoch 15/250
Learning rate:  0.1
Epoch 16/250
Learning rate:  0.1
Epoch 17/250
Learning rate:  0.1
Epoch 18/250
Learning rate:  0.1
Epoch 19/250
Learning rate:  0.1
Epoch 20/250
Learning rate:  0.1
Epoch 21/250
Learning rate:  0.1
Epoch 22/250
Learning rate:  0.1
Epoch 23/250
Learning rate:  0.1
Epoch 24/250
Learning rate:  0.1
Epoch 25/250
Learning rate:  0.1
Epoch 26/250
Learning rate:  0.1
Epoch 27/250
Learning rate:  0.1
Epoch 28/250
Learning rate:  0.1
Epoch 29/250
Learning rate:  0.1
Epoch 30/250
Learning rate:  0.1
Epoch 31/

<tensorflow.python.keras.callbacks.History at 0x7fcdd09306d0>

In [14]:
quat_resnet.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
quat_conv2d (QuatConv2D)     (None, None, None, 24)    216       
_________________________________________________________________
quat_residual_unit (QuatResi (None, None, None, 24)    2928      
_________________________________________________________________
quat_residual_unit_1 (QuatRe (None, None, None, 24)    2928      
_________________________________________________________________
quat_residual_unit_2 (QuatRe (None, None, None, 24)    2928      
_________________________________________________________________
quat_residual_unit_3 (QuatRe (None, None, None, 48)    9072      
_________________________________________________________________
quat_residual_unit_4 (QuatRe (None, None, None, 48)    11040     
_________________________________________________________________
quat_residual_unit_5 (QuatRe (None, None, None, 96)    3

In [15]:
accuracy_score(np.argmax(ytr_cifar, axis=1),np.argmax(quat_resnet.predict(Xtr_rgb),axis=1))

0.96704

In [16]:
accuracy_score(np.argmax(yte_cifar, axis=1),np.argmax(quat_resnet.predict(Xte_rgb),axis=1))

0.8499