In [1]:
import tensorflow as tf
import set_data

In [2]:
def conv_block(input, num_kernel, training):
    net = tf.layers.conv2d(input, num_kernel, kernel_size=[3, 3], strides=2, padding='SAME', activation=tf.nn.relu)
    net = tf.layers.batch_normalization(net, training=training)
    return net

In [3]:
def deconv_block(input, num_kernel, training):
    net = tf.layers.conv2d_transpose(input, num_kernel, kernel_size=[3, 3], strides=2, padding='SAME', activation=tf.nn.relu)
    net = tf.layers.batch_normalization(net, training=training)
    return net

In [4]:
class AE():
    def __init__(self):
        self.graph = tf.Graph()
        self.batch_size = 128
        with self.graph.as_default():
            with tf.device('/cpu:0'):
                train_datasets, _ = set_data.create_mnist_dataset(self.batch_size, 'train')
                iterator = train_datasets.make_one_shot_iterator()
                self.images, labels =iterator.get_next()
            self.generated_images = self._build_graph(self.images, 5, reuse=tf.AUTO_REUSE, training=True)
            self.loss = self._loss_function(self.images, self.generated_images)
            self.solver = tf.train.AdamOptimizer(learning_rate=0.0001) \
                           .minimize(self.loss)
            initializer = tf.global_variables_initializer()
            self.sess = tf.Session()
            self.sess.run(initializer)
    def train(self):
        for i in range(101):
            loss, _, np_real_images, np_generated_images = self.sess.run(
                    [self.loss, self.solver, self.images, self.generated_images])
            if i % 10 == 0:
                print("iterator {} : loss {} ".format(i, loss))
            if i % 1000 == 0:
                self.visualize(np_real_images, np_generated_images)
            
    def visualize(self, images, generated_images):
        pass
    def _build_graph(self, input, dim_code, reuse=tf.AUTO_REUSE, training=False):
        with tf.variable_scope('encoder', reuse = reuse):
            print(input)
            net = conv_block(input, 32, training) #28x28 -> 14x14
            net = conv_block(net, 32, training) #14x14 -> 7x7
            net = conv_block(net, 64, training) #7x7 -> 3x3
            net = tf.layers.flatten(net) # 3x3 -> 9
            latent_var = tf.layers.dense(net, dim_code) # 9 -> dim_code
        with tf.variable_scope('decoder', reuse = reuse):
            net = tf.layers.dense(latent_var, 7*7)
            net = tf.reshape(shape=[-1, 7, 7, 1], tensor=net)
            net = deconv_block(net, 32, training) #7x7 -> 14x14
            net = deconv_block(net, 64, training) #14x14 -> 28x28
            net = tf.layers.conv2d_transpose(net, 1, [3,3], padding = "SAME", activation=None)
            generated = tf.nn.tanh(net)
        return generated
    def _loss_function(self, _real_images, _generated_images):
        recon_loss = tf.reduce_mean(tf.square(_real_images - _generated_images))
        return recon_loss

In [5]:
x = AE()

Tensor("IteratorGetNext:0", shape=(?, 28, 28, 1), dtype=float32, device=/device:CPU:0)


In [6]:
x.train()

iterator 0 : loss 1.6106016635894775 
iterator 10 : loss 1.37012779712677 
iterator 20 : loss 1.1769214868545532 
iterator 30 : loss 1.0598185062408447 


KeyboardInterrupt: 