In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import *

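## Simple Custom Dense Layer

A minimal fully connected layer with no activation. It computes `inputs @ w + b`, where the kernel `w` and bias `b` are created in `build()` from the input's last dimension.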

In [2]:
class CustomDense(layers.Layer):
    """A minimal fully connected layer computing ``inputs @ w + b`` (no activation).

    Args:
        units: Size of the output's last dimension.
        **kwargs: Standard ``Layer`` keyword arguments (``name``, ``dtype``,
            ``trainable``, ...), forwarded to the base class.
    """

    def __init__(self, units=32, **kwargs):
        # Forward base-layer kwargs so e.g. CustomDense(4, name="head") works;
        # the previous zero-argument super().__init__() made that a TypeError.
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        """Create the kernel ``(input_dim, units)`` and bias ``(units,)``.

        Keras calls this once, on the first call, with the shape of the inputs.
        """
        input_dim = input_shape[-1]
        # add_weight (rather than a raw tf.Variable) creates the weight in the
        # layer's own dtype instead of a hard-coded 'float32' and registers it
        # with Keras weight tracking.
        self.w = self.add_weight(
            name='kernel',
            shape=(input_dim, self.units),
            initializer='random_normal',
            trainable=True,
        )
        self.b = self.add_weight(
            name='bias',
            shape=(self.units,),
            initializer='zeros',
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        """Return ``tf.matmul(inputs, w) + b``."""
        return tf.matmul(inputs, self.w) + self.b

    def get_config(self):
        """Return the constructor config so the layer can be saved and re-created."""
        config = super().get_config()
        config.update({'units': self.units})
        return config

### Basic Demo

Call the layer on a single 1×1 input, then inspect the weights that the call created.

In [3]:
# One sample with one feature: a float32 tensor of shape (1, 1).
inputs = tf.constant([[5.0]], dtype=tf.float32)
inputs

Metal device set to: Apple M1


2022-01-30 11:49:21.891800: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2022-01-30 11:49:21.891885: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[5.]], dtype=float32)>

In [4]:
# Seed TensorFlow's global RNG so the randomly initialised kernel (created on
# the layer's first call) and the outputs that depend on it are reproducible
# on Restart & Run All.
SEED = 42
tf.random.set_seed(SEED)
custom_dense = CustomDense(units=1)

In [5]:
# The first call builds the layer: build() receives the (1, 1) input shape and
# creates a (1, 1) kernel and a (1,) bias, then call() computes inputs @ w + b.
custom_dense(inputs)

<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[-0.09382155]], dtype=float32)>

In [6]:
# The kernel and bias created by build(); on a Keras layer `.variables` is an
# alias of `.weights`.
custom_dense.weights

[<tf.Variable 'custom_dense/kernel:0' shape=(1, 1) dtype=float32, numpy=array([[-0.01876431]], dtype=float32)>,
 <tf.Variable 'custom_dense/bias:0' shape=(1,) dtype=float32, numpy=array([0.], dtype=float32)>]

## Load and Preprocess Data

Load MNIST and scale the pixel values from [0, 255] to [0, 1].

In [7]:
# Fully qualified loader path instead of the star-imported `datasets` name.
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

In [8]:
def scale_pixels(images):
    """Return ``images`` as float32 scaled from [0, 255] to [0, 1].

    Non-float input (the raw MNIST arrays; uint8 per the Keras docs) is cast to
    float32 and divided by 255. Floating-point input is assumed to be already
    scaled and is returned unchanged. Re-running this cell is therefore a no-op
    instead of dividing by 255 a second time.
    """
    if np.issubdtype(images.dtype, np.floating):
        return images
    return images.astype('float32') / 255.0


X_train = scale_pixels(X_train)
X_test = scale_pixels(X_test)