<a href="https://colab.research.google.com/github/DhruvSrikanth/Model-Pipelines/blob/master/skip_connections_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [26]:
# ----------------------------------------------Import required Modules----------------------------------------------- #
import tensorflow as tf
from tensorflow.keras.applications import VGG16, ResNet50, DenseNet121
from tensorflow.keras.initializers import glorot_uniform

In [27]:
def conv_block(X, f, filters, s = 2):
    """Residual block with a 1x1 convolution on the shortcut path.

    Main path: 1x1 conv (stride ``s``) -> BN -> ReLU -> ``f``x``f`` 'same' conv
    -> BN -> ReLU -> 1x1 conv -> BN.
    Shortcut: 1x1 conv (stride ``s``) -> BN.
    The two paths are summed and passed through a final ReLU. The tensor shape
    after each stage is printed for inspection.

    :param X: input tensor (BatchNormalization uses axis 3, so a channels-last
        4-D tensor is presumably expected)
    :param f: kernel size of the middle convolution
    :param filters: three filter counts (F1, F2, F3) for the main-path convs;
        F3 is also the filter count of the shortcut conv
    :param s: stride of the first main-path conv and of the shortcut conv
    :return: output tensor of the block
    """
    F1, F2, F3 = filters

    # One row per main-path stage: (filters, kernel, strides, padding, relu?, label).
    stages = (
        (F1, (1, 1), (s, s), 'valid', True, 'main path (post 1st conv) shape = '),
        (F2, (f, f), (1, 1), 'same', True, 'main path (post 2nd conv) shape = '),
        (F3, (1, 1), (1, 1), 'valid', False, 'main path (post 3rd conv) shape = '),
    )

    main = X
    for n_filters, kernel, stride, pad, with_relu, label in stages:
        main = tf.keras.layers.Conv2D(filters=n_filters, kernel_size=kernel, strides=stride, padding=pad, kernel_initializer=glorot_uniform(seed=0))(main)
        main = tf.keras.layers.BatchNormalization(axis=3)(main)
        if with_relu:
            main = tf.keras.activations.relu(main)
        print(label, main.shape)

    # Shortcut conv uses the same stride and final filter count as the main
    # path so the two tensors can be added element-wise.
    shortcut = tf.keras.layers.Conv2D(filters=F3, kernel_size=(1, 1), strides=(s, s), padding='valid', kernel_initializer=glorot_uniform(seed=0))(X)
    shortcut = tf.keras.layers.BatchNormalization(axis=3)(shortcut)
    print('shortcut  (post 1st conv) shape = ', shortcut.shape)

    merged = tf.keras.layers.Add()([main, shortcut])
    return tf.keras.activations.relu(merged)

In [28]:
def deconv_block(X, f, filters, s = 2):
    """Transposed-convolution counterpart of ``conv_block`` (upsampling).

    Main path: 1x1 transposed conv (stride ``s``) -> BN -> ReLU -> ``f``x``f``
    'same' transposed conv -> BN -> ReLU -> 1x1 transposed conv -> BN.
    Shortcut: 1x1 transposed conv (stride ``s``) -> BN.
    The two paths are summed and passed through a final ReLU. The tensor shape
    after each stage is printed for inspection.

    NOTE(review): a 1x1 transposed conv with stride ``s`` > 1 maps each input
    pixel to a single output position, so the other positions hold only the
    bias before BatchNorm. Confirm this sparse upsampling is intended.

    :param X: input tensor (BatchNormalization uses axis 3, so a channels-last
        4-D tensor is presumably expected)
    :param f: kernel size of the middle transposed convolution
    :param filters: three filter counts (F1, F2, F3) for the main-path layers;
        F3 is also the filter count of the shortcut layer
    :param s: stride of the first main-path layer and of the shortcut layer
    :return: output tensor of the block
    """
    F1, F2, F3 = filters

    # One row per main-path stage: (filters, kernel, strides, padding, relu?, label).
    stages = (
        (F1, (1, 1), (s, s), 'valid', True, 'main path (post 1st conv) shape = '),
        (F2, (f, f), (1, 1), 'same', True, 'main path (post 2nd conv) shape = '),
        (F3, (1, 1), (1, 1), 'valid', False, 'main path (post 3rd conv) shape = '),
    )

    main = X
    for n_filters, kernel, stride, pad, with_relu, label in stages:
        main = tf.keras.layers.Conv2DTranspose(filters=n_filters, kernel_size=kernel, strides=stride, padding=pad, kernel_initializer=glorot_uniform(seed=0))(main)
        main = tf.keras.layers.BatchNormalization(axis=3)(main)
        if with_relu:
            main = tf.keras.activations.relu(main)
        print(label, main.shape)

    # Shortcut layer uses the same stride and final filter count as the main
    # path so the two tensors can be added element-wise.
    shortcut = tf.keras.layers.Conv2DTranspose(filters=F3, kernel_size=(1, 1), strides=(s, s), padding='valid', kernel_initializer=glorot_uniform(seed=0))(X)
    shortcut = tf.keras.layers.BatchNormalization(axis=3)(shortcut)
    print('shortcut  (post 1st conv) shape = ', shortcut.shape)

    merged = tf.keras.layers.Add()([main, shortcut])
    return tf.keras.activations.relu(merged)

In [29]:
# Symbolic 224x224, 3-channel input used to probe the blocks' output shapes.
inp = tf.keras.Input(shape=(224, 224, 3))
print('Input shape = ', inp.shape)

Input shape =  (None, 224, 224, 3)


In [30]:
# Probe the downsampling block with a middle kernel of 3 and stride 2.
X = conv_block(inp, 3, (16, 16, 64), s=2)

main path (post 1st conv) shape =  (None, 112, 112, 16)
main path (post 2nd conv) shape =  (None, 112, 112, 16)
main path (post 3rd conv) shape =  (None, 112, 112, 64)
shortcut  (post 1st conv) shape =  (None, 112, 112, 64)


In [31]:
# Observations - (CONV)
# kernel size does not change the output shape here: the 1x1 convs and the 'same'-padded fxf conv all keep the spatial dims
# filters set the last (channel) dim of the output shape -> number of output channels = number of filters
# stride sets the spatial (height, width) dims of the output shape -> a stride of s divides height and width by s, rounded up (2x, 3x, etc.)

In [32]:
# Keep the downsampled tensor under its own name; X is reassigned later.
out = X
print(out.shape)

(None, 112, 112, 64)


In [33]:
# Probe the upsampling block with a middle kernel of 3 and stride 2.
X = deconv_block(out, 3, (64, 16, 16), s=2)

main path (post 1st conv) shape =  (None, 224, 224, 64)
main path (post 2nd conv) shape =  (None, 224, 224, 16)
main path (post 3rd conv) shape =  (None, 224, 224, 16)
shortcut  (post 1st conv) shape =  (None, 224, 224, 16)


In [34]:
# Observations - (DECONV)
# kernel size does not change the output shape here: the 1x1 kernels and the 'same'-padded fxf kernel all keep the spatial dims
# filters set the last (channel) dim of the output shape -> number of output channels = number of filters
# stride sets the spatial (height, width) dims of the output shape -> a stride of s multiplies height and width by s (1x, 2x, etc.)


# I think for each layer we will have to adjust the padding and stride to change the spatial shape.
# The filter counts and kernel sizes can probably stay the same as before: they don't change the spatial shape here, and keeping them keeps the overall model size unchanged (goal: higher accuracy, lower resource usage, same model size).

In [35]:
# MODEL

In [38]:
# ----------------------------------------------Define Model---------------------------------------------------------- #

# Build complete autoencoder model
def build_autoencoder(input_shape = (224, 224, 3), describe = False):
    '''
    Build Autoencoder Model.

    The encoder is an ImageNet ResNet50 (frozen) truncated at "conv3_block1_out",
    followed by three trainable Conv2D/BatchNorm/ELU stages. Its output is
    reshaped into a (2, 2, 2, 256) latent volume and decoded by a stack of 3D
    transposed convolutions into a (32, 32, 32) sigmoid volume.

    NOTE: the latent Reshape needs the encoder to emit exactly 2*2*2*256 = 2048
    values per sample. That holds for the default 224x224 input (encoder output
    (4, 4, 128)); other input sizes generally will not match and the Reshape
    will raise when the model is built.

    :param input_shape: Input Shape passed to Autoencoder Model (224,224,3) (default)
    :param describe: if True, print the summaries of the encoder, decoder and
        full autoencoder models while building them
    :return: Autoencoder Model (tf.keras.Model)
    '''
    def encoder(inp, input_shape=input_shape):
        '''
        Build the encoder graph on top of ``inp``.

        :param inp: Input to Autoencoder Model
        :param input_shape: Input Shape passed to Autoencoder Model (224,224,3) (default)
        :return: output tensor of the encoder (the last ELU activation)
        '''
        # pooling=None (not the string "none") is the documented value for
        # "no global pooling" when include_top=False.
        cnn_model = ResNet50(include_top = False, weights = "imagenet", input_shape = input_shape, pooling = None)
        cnn_model.trainable = False
        pre_trained = tf.keras.models.Model(inputs = cnn_model.input, outputs = cnn_model.get_layer(name="conv3_block1_out").output, name = "resnet")

        # https://keras.io/guides/transfer_learning/
        # training=False keeps the frozen backbone's BatchNorm layers in inference mode.
        x = pre_trained(inputs = inp, training=False)

        # Activation layers (rather than tf.keras.activations.* called on
        # symbolic tensors) so that the given layer names are actually applied.
        layer10 = tf.keras.layers.Conv2D(filters = 512, kernel_size = 1, name = "conv10")(x) # for pix2vox-A(large), kernel_size is 3
        layer10_norm = tf.keras.layers.BatchNormalization(name="layer10_norm")(layer10)
        layer10_elu = tf.keras.layers.Activation("elu", name="layer10_elu")(layer10_norm)

        layer11 = tf.keras.layers.Conv2D(filters = 256, kernel_size = 3, name = "conv11")(layer10_elu) # for pix2vox-A(large), filters is 512
        layer11_norm = tf.keras.layers.BatchNormalization(name="layer11_norm")(layer11)
        layer11_elu = tf.keras.layers.Activation("elu", name="layer11_elu")(layer11_norm)
        layer11_pool = tf.keras.layers.MaxPooling2D(pool_size = (4,4), name="layer11_pool")(layer11_elu) # for pix2vox-A(large), kernel size is 3

        layer12 = tf.keras.layers.Conv2D(filters = 128, kernel_size = 3, name = "conv12")(layer11_pool) # for pix2vox-A(large), filters is 256, kernel_size is 1
        layer12_norm = tf.keras.layers.BatchNormalization(name="layer12_norm")(layer12)
        layer12_elu = tf.keras.layers.Activation("elu", name="layer12_elu")(layer12_norm)

        return layer12_elu

    def decoder(inp):
        '''
        Build the decoder graph on top of ``inp``.

        :param inp: Reshaped Output of Encoder Model, shape (batch, 2, 2, 2, 256)
        :return: output tensor of the decoder, shape (batch, 32, 32, 32)
        '''
        layer1 = tf.keras.layers.Convolution3DTranspose(filters=128, kernel_size=4, strides=(2,2,2), padding="same", use_bias=False, name="Conv3D_1")(inp)
        layer1_norm = tf.keras.layers.BatchNormalization(name="layer1_norm")(layer1)
        layer1_relu = tf.keras.layers.Activation("relu", name="layer1_relu")(layer1_norm)

        layer2 = tf.keras.layers.Convolution3DTranspose(filters=64, kernel_size=4, strides=(2,2,2), padding="same", use_bias=False, name="Conv3D_2")(layer1_relu)
        layer2_norm = tf.keras.layers.BatchNormalization(name="layer2_norm")(layer2)
        layer2_relu = tf.keras.layers.Activation("relu", name="layer2_relu")(layer2_norm)

        layer3 = tf.keras.layers.Convolution3DTranspose(filters=32, kernel_size=4, strides=(2,2,2), padding="same", use_bias=False, name="Conv3D_3")(layer2_relu)
        layer3_norm = tf.keras.layers.BatchNormalization(name="layer3_norm")(layer3)
        layer3_relu = tf.keras.layers.Activation("relu", name="layer3_relu")(layer3_norm)

        layer4 = tf.keras.layers.Convolution3DTranspose(filters=8, kernel_size=4, strides=(2,2,2), padding="same", use_bias=False, name="Conv3D_4")(layer3_relu)
        layer4_norm = tf.keras.layers.BatchNormalization(name="layer4_norm")(layer4)
        layer4_relu = tf.keras.layers.Activation("relu", name="layer4_relu")(layer4_norm)

        layer5 = tf.keras.layers.Convolution3DTranspose(filters=1, kernel_size=1, use_bias=False, name="Conv3D_5")(layer4_relu)
        layer5_sigmoid = tf.keras.layers.Activation("sigmoid", name="layer5_sigmoid")(layer5)

        # The four stride-2 'same' transposed convs grow the 2x2x2 latent to
        # 32x32x32, and the final 1-filter conv leaves a trailing channel axis
        # of size 1; the Reshape drops it to give (32, 32, 32) per sample.
        layer5_sigmoid = tf.keras.layers.Reshape((32,32,32))(layer5_sigmoid)

        return layer5_sigmoid

    # Shape the flat encoder output is reshaped into before decoding.
    latent_shape = (2, 2, 2, 256)

    # Input (named `inputs` to avoid shadowing the builtin `input`)
    inputs = tf.keras.Input(shape = input_shape, name = "input_layer")

    # Encoder Model
    encoder_model = tf.keras.Model(inputs, encoder(inputs), name = "encoder")
    if describe:
        print("\nEncoder Model Summary:\n")
        encoder_model.summary()

    # Decoder Model, built on its own symbolic input of the latent shape
    latent_input = tf.keras.Input(shape=latent_shape, name = "decoder_input")
    decoder_model = tf.keras.Model(latent_input, decoder(latent_input), name = "decoder")
    if describe:
        print("\nDecoder Model Summary:\n")
        decoder_model.summary()

    # Autoencoder Model
    encoder_output = encoder_model(inputs)
    # the encoder output is reshaped to (-1, 2, 2, 2, 256) to be fed into the decoder
    reshaped_latent = tf.keras.layers.Reshape(latent_shape)(encoder_output)

    autoencoder_model = tf.keras.Model(inputs, decoder_model(reshaped_latent), name ='autoencoder')
    if describe:
        print("\nAutoencoder Model Summary:\n")
        autoencoder_model.summary()

    return autoencoder_model

In [39]:
autoencoder_model = build_autoencoder()
# Model.summary() prints the table itself and returns None, so wrapping it in
# print() would only add a stray "None" line to the output.
autoencoder_model.summary()

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
Model: "autoencoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_layer (InputLayer)     [(None, 224, 224, 3)]     0         
_________________________________________________________________
encoder (Functional)         (None, 4, 4, 128)         2354176   
_________________________________________________________________
reshape_5 (Reshape)          (None, 2, 2, 2, 256)      0         
_________________________________________________________________
decoder (Functional)         (None, 32, 32, 32)        2769832   
Total params: 5,124,008
Trainable params: 4,508,760
Non-trainable params: 615,248
_________________________________________________________________
None
