In [1]:
import os
import numpy as np

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

from tensorflow.keras.layers import (Dense, 
                                     BatchNormalization, 
                                     LeakyReLU, 
                                     Reshape, 
                                     Conv2DTranspose,
                                     Conv2D,
                                     Dropout,
                                     Flatten)

import tensorflow_addons as tfa

from tensorflow.keras.applications import EfficientNetB3
#from tensorflow.keras.applications import EfficientNetB7
from keras.models import Model
from keras.models import Sequential
import efficientnet.tfkeras
import efficientnet.keras as efn 

In [18]:
class ReflectionPadding2D(layers.Layer):
    """Reflection-pad the two spatial axes of a rank-4 tensor.

    Args:
        padding (tuple): ``(width_pad, height_pad)``. Note the order: the first
            value is applied on both sides of axis 2 and the second value on
            both sides of axis 1. Assuming the usual channels-last layout
            ``(batch, height, width, channels)`` (not checked here), axis 1 is
            height and axis 2 is width. Defaults to ``(1, 1)``.

    Returns:
        A tensor of the same dtype as the input, padded with ``tf.pad`` in
        ``"REFLECT"`` mode; the first and last axes are not padded.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        # Initialise the base Layer before assigning attributes on it.
        super(ReflectionPadding2D, self).__init__(**kwargs)
        self.padding = tuple(padding)

    def call(self, input_tensor, mask=None):
        # `mask` is accepted for Keras API compatibility and is unused.
        padding_width, padding_height = self.padding
        padding_tensor = [
            [0, 0],
            [padding_height, padding_height],
            [padding_width, padding_width],
            [0, 0],
        ]
        return tf.pad(input_tensor, padding_tensor, mode="REFLECT")

    def get_config(self):
        """Return the constructor arguments so Keras can re-create the layer.

        Keras needs layers with extra ``__init__`` arguments to override
        ``get_config`` before a model containing them can be saved and
        reloaded (with ``custom_objects``); otherwise ``padding`` is lost.
        """
        config = super(ReflectionPadding2D, self).get_config()
        config.update({"padding": self.padding})
        return config


In [79]:
def efficientnet(image_size=256, cut_index=-40, freeze_backbone=False, weights="imagenet"):
    """Build an EfficientNetB3 backbone truncated at an intermediate layer.

    NOTE(review): another cell in this notebook also defines ``efficientnet``.
    Under Restart & Run All, whichever definition runs last silently shadows
    the other; consider giving the two builders distinct names.

    Args:
        image_size: Height and width of the square RGB input. Defaults to 256,
            the value previously hard-coded.
        cut_index: Index into the backbone's ``layers`` whose output becomes
            the model output. Defaults to -40, the value previously hard-coded.
        freeze_backbone: If True, mark every backbone layer non-trainable.
            Defaults to False, which matches the previous behaviour.
        weights: Passed through to ``efn.EfficientNetB3``.

    Returns:
        A compiled Keras ``Model`` mapping the backbone input to
        ``backbone.layers[cut_index].output``. Its summary is printed.
    """
    backbone = efn.EfficientNetB3(
        include_top=False,
        weights=weights,
        input_shape=(image_size, image_size, 3),
    )

    if freeze_backbone:
        for layer in backbone.layers:
            layer.trainable = False

    truncated = Model(backbone.input, backbone.layers[cut_index].output)

    # NOTE(review): compile() is called without a loss or optimizer; confirm
    # this is intended before using the model for training.
    truncated.compile()
    truncated.summary()
    return truncated

In [23]:
def efficientnet(image_size=256, feature_layer_index=-36, freeze_backbone=False, weights="imagenet"):
    """Build an image-to-image generator: EfficientNetB3 encoder + transposed-conv decoder.

    The encoder is ``efn.EfficientNetB3`` (no top). The output of its
    ``layers[feature_layer_index]`` is passed through the decoder:

    - three 64-filter, 3x3 ``Conv2DTranspose`` layers with strides 4, 4 and 2,
      each followed by instance normalization;
    - a 3-pixel reflection padding;
    - an 8x8 "valid" convolution down to 3 channels;
    - a ``tanh`` activation.

    Shape note (hedged): assuming Keras' "valid" transposed-conv length rule
    ``L * s + max(k - s, 0)``, the decoder maps a spatial length ``L`` to
    ``32 * L``. The output therefore matches ``image_size`` only if the chosen
    feature map is ``image_size / 32`` wide. TODO confirm with ``summary()``.

    NOTE(review): another cell in this notebook also defines ``efficientnet``;
    whichever definition runs last silently shadows the other.

    Args:
        image_size: Height and width of the square RGB input. Defaults to 256,
            the value previously hard-coded.
        feature_layer_index: Index into the backbone's ``layers`` used as the
            decoder input. Defaults to -36, the value previously hard-coded.
        freeze_backbone: If True, mark every pretrained backbone layer
            non-trainable. Defaults to False, which matches the previous
            behaviour.
        weights: Passed through to ``efn.EfficientNetB3``.

    Returns:
        A compiled Keras ``Model`` from the backbone inputs to the ``tanh``
        output. Its summary is printed.
    """
    backbone = efn.EfficientNetB3(
        include_top=False,
        weights=weights,
        input_shape=(image_size, image_size, 3),
    )

    if freeze_backbone:
        for layer in backbone.layers:
            layer.trainable = False

    # Decoder: upsample the intermediate feature map back towards image size.
    x = backbone.layers[feature_layer_index].output
    for stride in (4, 4, 2):
        x = layers.Conv2DTranspose(
            filters=64, kernel_size=(3, 3), strides=(stride, stride), use_bias=False
        )(x)
        x = tfa.layers.InstanceNormalization()(x)

    # Reflection padding before the "valid" conv avoids zero-padded borders.
    x = ReflectionPadding2D(padding=(3, 3))(x)
    x = layers.Conv2D(3, (8, 8), padding="valid")(x)
    x = layers.Activation("tanh")(x)

    generator = Model(inputs=backbone.inputs, outputs=x)

    # NOTE(review): compile() is called without a loss or optimizer; confirm
    # this is intended before using the model for training.
    generator.compile()
    generator.summary()
    return generator

In [24]:
# Keep a handle on the built model instead of relying on Out[24] / `_`.
# efficientnet() already prints model.summary(), so the object repr adds nothing.
generator = efficientnet()

Model: "functional_19"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_11 (InputLayer)           [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
stem_conv (Conv2D)              (None, 128, 128, 40) 1080        input_11[0][0]                   
__________________________________________________________________________________________________
stem_bn (BatchNormalization)    (None, 128, 128, 40) 160         stem_conv[0][0]                  
__________________________________________________________________________________________________
stem_activation (Activation)    (None, 128, 128, 40) 0           stem_bn[0][0]                    
______________________________________________________________________________________

<tensorflow.python.keras.engine.functional.Functional at 0x1e234f40610>