# 오토인코더

In [1]:
import os

In [None]:
from tensorflow.keras.datasets import mnist

def load_mnist():
    """Load MNIST through Keras and prepare the images for a conv network.

    Both image arrays are cast to ``float32``, divided by 255 and given one
    extra trailing axis of size 1. The label arrays are returned exactly as
    ``mnist.load_data()`` provides them.

    Returns:
        ``((x_train, y_train), (x_test, y_test))`` with preprocessed images.
    """
    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()

    def _prepare(images):
        # Presumably maps 0-255 pixel intensities into [0, 1] -- the source
        # dtype/range comes from Keras and is not visible here.
        scaled = images.astype('float32') / 255.
        # Append a trailing singleton axis to the existing shape.
        return scaled.reshape(scaled.shape + (1, ))

    prepared_train = _prepare(train_images)
    prepared_test = _prepare(test_images)

    return (prepared_train, train_labels), (prepared_test, test_labels)

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, LeakyReLU, BatchNormalization, Dropout, Flatten

class AutoEncoder():
    def __init__(self, 
                 input_dim, 
                 encoder_conv_filters, 
                 encoder_conv_kernel_size, 
                 encoder_conv_strides, 
                 decoder_conv_t_filters, 
                 decoder_conv_t_kernel_size,
                 decoder_conv_t_strides, 
                 z_dim, 
                 use_batch_norm=False, 
                 use_dropout=False):
        
        self.name = 'AutoEncoder'

        self.input_dim = input_dim
        self.encoder_conv_filters = encoder_conv_filters
        self.encoder_conv_kernel_size = encoder_conv_kernel_size
        self.encoder_conv_strides = encoder_conv_strides
        self.decoder_conv_t_filters = decoder_conv_t_filters
        self.decoder_conv_t_kernel_size = decoder_conv_t_kernel_size
        self.decoder_conv_t_strides = decoder_conv_t_strides
        self.z_dim = z_dim
        self.use_batch_norm = use_batch_norm
        self.use_dropout = use_dropout

        self.n_layers_encoder = len(encoder_conv_filters)
        self.n_layers_decoder = len(decoder_conv_t_filters)

        self._build()
    
    def _build(self):

        ### BUILD ENCODER ###
        encoder_input = Input(shape=self.input_dim, name='encoder_input')

        x = encoder_input

        for i in range(self.n_layers_encoder):
            conv_layer = Conv2D(
                filters=self.encoder_conv_filters[i], 
                kernel_size=self.encoder_conv_kernel_size[i], 
                strides=self.encoder_conv_strides[i], 
                padding='same', 
                name=f'encoder_conv_{i}'
            )
            x = conv_layer(x)
            x = LeakyReLU()(x)

            if self.use_batch_norm:
                x = BatchNormalization()(x)
            if self.use_dropout:
                x = Dropout(rate=.25)(x)
        
        shape_before_flattening = tf.shape(x)[1:]

        x = Flatten()(x)
