In [1]:
from sys import path
path.append('src/')

import numpy as np
import tensorflow as tf
import pickle as pkl
from keras.models import Model
from keras.layers import Input, Dense, Conv2D, Conv2DTranspose, MaxPool2D, Dropout, BatchNormalization, Activation, UpSampling2D
from keras.utils import plot_model

Using TensorFlow backend.


In [2]:
# Load the pre-split (train, validation) autoencoder samples.
# A context manager closes the file handle as soon as unpickling finishes;
# the original `pkl.load(open(...))` left the handle open until garbage collection.
# NOTE: pickle.load can execute arbitrary code — only load this file from a trusted source.
with open('data/autoencoder_samples.pkl', 'rb') as samples_file:
    x_train, x_val = pkl.load(samples_file)

print(x_train.shape, x_train.dtype)
print(x_val.shape, x_val.dtype)

(5211, 232, 232, 3) float32
(571, 232, 232, 3) float32


In [3]:
# Stack the training and validation samples into one array along the sample axis.
# NOTE(review): this merges the validation split into `x`; if later cells train on `x`,
# there is no held-out data left — confirm that this is intended.
x = np.concatenate([x_train, x_val], axis=0)
print(x.shape)

(5782, 232, 232, 3)


In [4]:
def conv_bn(idx, prev_input, filters):
    """Apply a 3x3 'same' convolution, batch normalisation and a ReLU.

    The created layers are named ``conv_<idx>``, ``bn_<idx>`` and ``relu_<idx>``.

    Args:
        idx: suffix used to build unique layer names.
        prev_input: Keras tensor fed into the convolution.
        filters: number of output channels of the convolution.

    Returns:
        The Keras tensor produced by the ReLU activation.
    """
    suffix = str(idx)
    # The convolution is left linear; the non-linearity is applied after BN.
    features = Conv2D(
        filters,
        kernel_size=(3, 3),
        strides=(1, 1),
        padding='same',
        activation=None,
        name='conv_' + suffix,
    )(prev_input)
    normalized = BatchNormalization(axis=-1, name='bn_' + suffix)(features)
    return Activation('relu', name='relu_' + suffix)(normalized)

def conv_bn_maxpool(idx, prev_input, filters):
    """Run ``conv_bn`` and then downsample with a 2x2, stride-2 max-pool.

    The pooling layer is named ``pool_<idx>``. With 'valid' padding, each spatial
    dimension is divided by 2, rounding down for odd sizes.

    Args:
        idx: suffix used to build unique layer names (shared with ``conv_bn``).
        prev_input: Keras tensor fed into the convolution.
        filters: number of output channels of the convolution.

    Returns:
        The pooled Keras tensor.
    """
    activated = conv_bn(idx, prev_input, filters)
    downsample = MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding='valid',
                           name='pool_' + str(idx))
    return downsample(activated)

def up_conv(idx, prev_input, filters):
    """Upsample by 2 in height and width, then apply a linear 3x3 'same' convolution.

    There is no batch normalisation or activation, so the output is the raw
    convolution result. The layers are named ``<idx>_up`` and ``<idx>_conv``.

    Args:
        idx: prefix used to build unique layer names.
        prev_input: Keras tensor to upsample.
        filters: number of output channels of the convolution.

    Returns:
        The Keras tensor produced by the convolution.
    """
    prefix = str(idx)
    upsampled = UpSampling2D(size=(2, 2), name=prefix + '_up')(prev_input)
    projection = Conv2D(filters, kernel_size=(3, 3), strides=(1, 1), padding='same',
                        activation=None, name=prefix + '_conv')
    return projection(upsampled)

def up_conv_bn(idx, prev_input, filters):
    """Upsample by 2 in height and width, then apply ``deconv_bn`` (conv + BN + ReLU).

    The upsampling layer is named ``<idx>_up``. The remaining layers use the
    names produced by ``deconv_bn`` for the same ``idx``.

    Args:
        idx: prefix used to build unique layer names.
        prev_input: Keras tensor to upsample.
        filters: number of output channels of the convolution.

    Returns:
        The Keras tensor produced by the ReLU activation.
    """
    upsample = UpSampling2D(size=(2, 2), name=str(idx) + '_up')
    return deconv_bn(idx, upsample(prev_input), filters)

def deconv_bn(idx, prev_input, filters):
    """Decoder-side 3x3 'same' convolution, then batch normalisation and a ReLU.

    Despite the name, this uses a regular ``Conv2D``, not a transposed convolution.
    The layers are named ``<idx>_conv``, ``<idx>_bn`` and ``<idx>_relu``. The
    prefix ordering keeps these names distinct from the encoder's ``conv_bn`` names.

    Args:
        idx: prefix used to build unique layer names.
        prev_input: Keras tensor fed into the convolution.
        filters: number of output channels of the convolution.

    Returns:
        The Keras tensor produced by the ReLU activation.
    """
    prefix = str(idx)
    conv_layer = Conv2D(filters, kernel_size=(3, 3), strides=(1, 1), padding='same',
                        activation=None, name=prefix + '_conv')
    bn_layer = BatchNormalization(axis=-1, name=prefix + '_bn')
    relu_layer = Activation('relu', name=prefix + '_relu')
    return relu_layer(bn_layer(conv_layer(prev_input)))

In [5]:
# Encoder: three conv-BN-ReLU + 2x2 max-pool stages, then a conv-BN-ReLU bottleneck.
inputs = Input(shape=x.shape[1:])
pool_1 = conv_bn_maxpool(1, inputs, 8)
pool_2 = conv_bn_maxpool(2, pool_1, 16)
pool_3 = conv_bn_maxpool(3, pool_2, 32)
encode = conv_bn(4, pool_3, 64)

# Decoder: mirror of the encoder. The last layer is a linear conv that
# reconstructs as many channels as the input has (3 for the data loaded above).
dec_4 = deconv_bn(4, encode, 32)
dec_3 = up_conv_bn(3, dec_4, 16)
dec_2 = up_conv_bn(2, dec_3, 8)
decoded = up_conv(1, dec_2, x.shape[-1])

model = Model(inputs=inputs, outputs=decoded)
# Three 2x2 pools followed by three 2x2 upsamplings only round-trip the spatial
# size when height and width are divisible by 8. Fail fast here rather than
# with an opaque shape error at training time.
assert model.output_shape == model.input_shape, (model.input_shape, model.output_shape)
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 232, 232, 3)       0         
_________________________________________________________________
conv_1 (Conv2D)              (None, 232, 232, 8)       224       
_________________________________________________________________
bn_1 (BatchNormalization)    (None, 232, 232, 8)       32        
_________________________________________________________________
relu_1 (Activation)          (None, 232, 232, 8)       0         
_________________________________________________________________
pool_1 (MaxPooling2D)        (None, 116, 116, 8)       0         
_________________________________________________________________
conv_2 (Conv2D)              (None, 116, 116, 16)      1168      
_________________________________________________________________
bn_2 (BatchNormalization)    (None, 116, 116, 16)      64        
__________