In [None]:
import tensorflow as tf
from tensorflow.keras import layers

In [None]:
from tensorflow.keras import datasets

# MNIST ships as 28x28 grayscale uint8 images. Each image is padded to 32x32,
# rescaled to [0, 1], and its single channel is repeated three times.
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()


def _prepare_images(images):
    """Pad 2 px on each spatial side, divide by 255, and tile to 3 channels."""
    padded = tf.pad(images, [[0, 0], [2, 2], [2, 2]]) / 255
    with_channel_axis = tf.expand_dims(padded, axis=3)
    return tf.repeat(with_channel_axis, 3, axis=3)


x_train = _prepare_images(x_train)
x_test = _prepare_images(x_test)

# Hold out the last 2000 training examples as the validation split.
x_train, x_val = x_train[:-2000], x_train[-2000:]
y_train, y_val = y_train[:-2000], y_train[-2000:]

In [None]:
class conv_block(tf.keras.layers.Layer):
    """Conv2D (3x3 kernel, 'same' padding) -> BatchNormalization -> ReLU.

    Args:
        num_channels: Number of filters produced by the convolution.
    """

    def __init__(self, num_channels: int):
        super().__init__()
        # Sub-layers are created in the same order as before (conv, then bn)
        # so Keras tracks their weights in the same order.
        self.conv = layers.Conv2D(num_channels, kernel_size=3, padding='same')
        self.bn = layers.BatchNormalization()

    def call(self, inp):
        """Apply conv, batch norm, and ReLU to `inp`."""
        normalized = self.bn(self.conv(inp))
        return tf.nn.relu(normalized)

In [None]:
class DenseBlock(tf.keras.layers.Layer):
    """Stack of `conv_block`s with dense (concatenative) connectivity.

    Each conv_block receives the concatenation of the block input and every
    earlier conv_block output. Each one appends `num_channels` new channels
    along the last axis, so the output has
    `input_channels + num_convs * num_channels` channels.

    Args:
        num_convs: Number of conv_blocks in the stack.
        num_channels: Filters produced by each conv_block (the growth rate).
    """

    def __init__(self, num_convs, num_channels):
        super(DenseBlock, self).__init__()
        self.layer_list = [conv_block(num_channels) for _ in range(num_convs)]

    def call(self, x):
        """Run each conv_block and concatenate its output onto the running features."""
        for blk in self.layer_list:
            out = blk(x)
            # tf.concat performs the same last-axis concatenation as
            # layers.concatenate, without building a new Keras Concatenate
            # layer on every forward pass.
            x = tf.concat([x, out], axis=-1)
        return x

In [None]:
class TransitionBlock(tf.keras.layers.Layer):
    """BN -> ReLU -> 1x1 Conv2D -> 2x2 average pooling with stride 2.

    Sets the channel count to `num_channels` and downsamples the spatial
    dimensions by a factor of 2 between dense stages.

    Args:
        num_channels: Filters produced by the 1x1 convolution.
    """

    def __init__(self, num_channels: int):
        super().__init__()
        self.conv = layers.Conv2D(num_channels, kernel_size=1)
        self.bn = layers.BatchNormalization()
        # Despite the name `ga`, this is a local 2x2 average pool, not a
        # global one. The attribute name is kept for compatibility.
        self.ga = layers.AveragePooling2D(pool_size=2, strides=2)

    def call(self, inp):
        """Normalize and activate `inp`, project it with the 1x1 conv, then pool."""
        activated = tf.nn.relu(self.bn(inp))
        projected = self.conv(activated)
        return self.ga(projected)

In [None]:
class DenseNet(tf.keras.Model):
    """DenseNet classifier: stem, dense/transition stages, and a softmax head.

    Args:
        num_channels: Filters in the stem convolution. This is also the
            starting value of the channel bookkeeping used to size the
            transition blocks.
        growth_rate: Filters added by each conv_block inside a DenseBlock.
        arch: Number of conv_blocks in each DenseBlock, one entry per stage.
        num_classes: Number of softmax outputs.
    """

    def __init__(self, num_channels=64, growth_rate=32, arch=(4, 4, 4, 4), num_classes=10):
        super(DenseNet, self).__init__()
        self.num_channels = num_channels
        self.growth_rate = growth_rate

        # Stem. Its width must equal `num_channels`; otherwise the channel
        # bookkeeping below (and therefore the transition widths) is wrong.
        self.b1 = tf.keras.Sequential([
            layers.Conv2D(num_channels, kernel_size=(7, 7), strides=(2, 2), padding='same'),
            layers.BatchNormalization(),
            layers.ReLU(),
            layers.MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='same')
        ])

        self.dense_blocks = []
        self.transition_blocks = []

        for i, num_convs in enumerate(arch):
            self.dense_blocks.append(DenseBlock(num_convs, growth_rate))
            # Each conv_block in the stage concatenates `growth_rate` channels.
            self.num_channels += num_convs * growth_rate
            # A transition that halves the channel count follows every stage
            # except the last, so there is one fewer transition than stages.
            if i != len(arch) - 1:
                self.transition_blocks.append(TransitionBlock(self.num_channels // 2))
                self.num_channels //= 2

        self.last = tf.keras.Sequential([
            layers.BatchNormalization(),
            layers.ReLU(),
            layers.GlobalAveragePooling2D(),
            layers.Flatten(),
            layers.Dense(num_classes, activation='softmax')
        ])

    def call(self, x):
        """Run the stem, every dense stage (with transitions in between), and the head."""
        x = self.b1(x)
        for i, dense_block in enumerate(self.dense_blocks):
            x = dense_block(x)
            # There are len(dense_blocks) - 1 transitions. Pairing the two
            # lists with zip() would silently skip the final dense block.
            if i < len(self.transition_blocks):
                x = self.transition_blocks[i](x)

        x = self.last(x)
        return x

In [None]:
# Seed TensorFlow's global RNG before building the model so that weight
# initialisation is repeatable across Restart & Run All. Some GPU kernels
# may still be nondeterministic.
SEED = 42
tf.random.set_seed(SEED)

model = tf.keras.Sequential([DenseNet()])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, validation_data=(x_val, y_val), batch_size=64, epochs=10)