In [1]:
import tensorflow as tf
from tensorflow.keras import layers

from GDN_Pablo import GDN as GDNP
from GDN_Jorge import GDN as GDNJ

In [2]:
def generate_perceptnet(kernel_size=3, normalization='gdn'):
    """Build the four normalization layers used between PerceptNet stages.

    Despite its name, this returns a plain list of four normalization layers,
    not a model; the caller is responsible for interleaving them with the
    convolutional stages.

    Args:
        kernel_size: Passed positionally to each ``GDNJ`` layer. Ignored when
            ``normalization == 'batch_norm'``.
        normalization: One of ``'gdn'`` (default), ``'batch_norm'`` or
            ``'instance_norm'``.

    Returns:
        A list of four Keras layers.

    Raises:
        NotImplementedError: if ``normalization == 'instance_norm'`` (not yet
            implemented; previously this fell through and failed with an
            ``UnboundLocalError``).
        ValueError: for any other unrecognized ``normalization`` value.
    """
    if normalization == 'batch_norm':
        # Mirrors PyTorch's BatchNorm defaults (momentum=0.1, eps=1e-5).
        # PyTorch weights the *new* batch statistic by `momentum`, whereas
        # Keras weights the *running* statistic by `momentum`, so PyTorch's
        # 0.1 corresponds to Keras's 0.9. `epsilon` has the same meaning in both.
        return [
            layers.BatchNormalization(momentum=0.9, epsilon=1e-5)
            for _ in range(4)
        ]
    if normalization == 'gdn':
        # The first GDN is created with apply_independently=True, the rest with
        # False (see GDN_Jorge for the exact semantics of that flag).
        return [
            GDNJ(kernel_size, apply_independently=True),
            GDNJ(kernel_size, apply_independently=False),
            GDNJ(kernel_size, apply_independently=False),
            GDNJ(kernel_size, apply_independently=False),
        ]
    if normalization == 'instance_norm':
        raise NotImplementedError("normalization='instance_norm' is not implemented yet")
    raise ValueError(
        f"Unknown normalization {normalization!r}; "
        "expected 'gdn', 'batch_norm' or 'instance_norm'"
    )

In [3]:
def generate_perceptnet():
    """Assemble PerceptNet as an unbuilt ``tf.keras.Sequential`` model.

    Layer order: GDN -> conv(3 filters, 1x1) -> max-pool(2) -> GDN ->
    conv(6 filters, 5x5) -> max-pool(2) -> GDN -> conv(128 filters, 5x5) -> GDN.
    Every GDN uses ``kernel_size=1``; only the first is created with
    ``apply_independently=True``.

    Note: this cell redefines (and therefore shadows) the earlier
    ``generate_perceptnet(kernel_size, normalization)`` helper.
    """
    def conv(n_filters, size):
        # Stride-1, 'same'-padded convolution: keeps the spatial size unchanged.
        return layers.Conv2D(filters=n_filters, kernel_size=size, strides=1, padding='same')

    # Layers are created in the same order as listed above so that Keras's
    # auto-generated layer names are assigned identically.
    stack = [GDNJ(kernel_size=1, apply_independently=True)]
    stack += [conv(3, 1), layers.MaxPool2D(2), GDNJ(kernel_size=1)]
    stack += [conv(6, 5), layers.MaxPool2D(2), GDNJ(kernel_size=1)]
    stack += [conv(128, 5), GDNJ(kernel_size=1)]
    return tf.keras.Sequential(stack)

In [6]:
# Instantiate PerceptNet and build it for a batch of one 28x28 single-channel
# image so that summary() can report per-layer output shapes and parameter counts.
pnet = generate_perceptnet()
pnet.build((1, 28, 28, 1))  # (batch, height, width, channels): channels-last layout
pnet.summary()

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
gdn_8 (GDN)                  (1, 28, 28, 1)            2         
_________________________________________________________________
conv2d_6 (Conv2D)            (1, 28, 28, 3)            6         
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (1, 14, 14, 3)            0         
_________________________________________________________________
gdn_9 (GDN)                  (1, 14, 14, 3)            12        
_________________________________________________________________
conv2d_7 (Conv2D)            (1, 14, 14, 6)            456       
_________________________________________________________________
max_pooling2d_5 (MaxPooling2 (1, 7, 7, 6)              0         
_________________________________________________________________
gdn_10 (GDN)                 (1, 7, 7, 6)             