# BPCA Layer
The goal of this notebook is create a layer to perform BPCA on the input data. The layer will be used to replace pooling layers in a CNN. The layer will be implemented in TensorFlow.

In [8]:
import tensorflow as tf


class MinMaxPooling2D(tf.keras.layers.Layer):
    def __init__(self, pool_size=(2, 2), strides=None, **kwargs):
        super(MinMaxPooling2D, self).__init__(**kwargs)
        self.pool_size = pool_size
        self.strides = strides or pool_size

    def call(self, inputs):
        max_pool = tf.keras.backend.pool2d(
            inputs, pool_size=self.pool_size, strides=self.strides, padding='same', pool_mode='max')
        print(max_pool, 'max_pool')

        min_pool = tf.keras.backend.pool2d(
            inputs, pool_size=self.pool_size, strides=self.strides, padding='same', pool_mode='avg')
        print(min_pool, 'max_pool')

        output = tf.keras.layers.concatenate([max_pool, min_pool], axis=-1)
        print(output, 'output')

        return output

    def get_config(self):
        config = super(MinMaxPooling2D, self).get_config()
        config.update({'pool_size': self.pool_size, 'strides': self.strides})
        return config


class BPCALayer(tf.keras.layers.Layer):
    def __init__(self, pool_size=2, stride=2, n_components=1, **kwargs):
        super(BPCALayer, self).__init__(**kwargs)
        self.pool_size = pool_size
        self.stride = stride
        self.n_components = n_components

    def call(self, inputs):
        # make blocks
        batch = inputs.shape[0]
        n = inputs.shape[1]
        m = inputs.shape[2]

        # output = tf.zeros([batch, n // self.pool_size, m // self.pool_size], tf.float32)
        output = tf.constant([], shape=(
            0, n // self.pool_size, m // self.pool_size, 1), dtype=tf.float32)

        for index in range(0, batch):
            input = tf.reshape(inputs[index, :], (int(n), int(m)))

            output_matrix = []
            for i in range(0, n, self.stride):
                for j in range(0, m, self.stride):
                    # fix out of bounds
                    if i + self.pool_size > n or j + self.pool_size > m:
                        continue
                    output_matrix.append(
                        input[i:i + self.pool_size, j:j + self.pool_size])

            output_matrix = tf.convert_to_tensor(output_matrix, tf.float32)
            output_matrix = tf.reshape(
                output_matrix, (output_matrix.shape[0], output_matrix.shape[1]*output_matrix.shape[2]))

            n_components = self.n_components

            data = tf.cast(output_matrix, tf.float32)

            # Normalize the data by subtracting the mean and dividing by the standard deviation
            mean = tf.reduce_mean(data, axis=0)
            std = tf.math.reduce_std(data, axis=0)
            data = (data - mean) / std

            # Perform the Singular Value Decomposition (SVD) on the data
            s, u, v = tf.linalg.svd(data)

            # Extract the first n principal components from the matrix v
            pca_components = v[:, :n_components]

            # Perform the PCA transformation on the data
            transformed_data = tf.matmul(data, pca_components)

            output = tf.concat([output,
                                tf.expand_dims(
                                    tf.reshape(transformed_data, [n // self.pool_size, m // self.pool_size, 1]), axis=0)
                                ], axis=0)

        print(f'output: {output.shape}')

        return output


In [9]:
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten

model = Sequential()

model.add(tf.keras.layers.InputLayer((28, 28, 1), batch_size=32))
model.add(BPCALayer())
# model.add(MinMaxPooling2D())  # replace MaxPooling2D with MinMaxPooling2D
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))

# Compile the model
model.compile(optimizer='adam', loss='categorical_crossentropy',
              metrics=['accuracy'])
r = model.fit(x_train, y_train, batch_size=32, epochs=10,
              validation_data=(x_test, y_test))


output: (32, 14, 14, 1)
Epoch 1/10
output: (32, 14, 14, 1)
output: (32, 14, 14, 1)

InvalidArgumentError: ignored