In [1]:
import numpy as np
import pandas as pd
import os
from tqdm import tqdm
import tensorflow.contrib.slim as slim
import tensorflow as tf

# define network(class)

In [2]:
class SCAE(object):
    """
    Stacked Convolutional AutoEncoder
    """
    def __init__(self, mode='train', batch_size= 128, learning_rate = 0.0002):
        self.mode = mode
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.input_array = tf.placeholder(tf.float32, [None, 32, 32, 3], 'input_image')
        self.reconstruct = self.build_model()
        
        self.loss = tf.reduce_mean(tf.square(self.reconstruct - self.input_array))

    def Encoder(self, input_array, reuse = False):
        # input : (batch_size, 32, 32, 3)
        with tf.variable_scope('Encoder', reuse=reuse):
            with slim.arg_scope([slim.conv2d], padding='SAME', activation_fn=None, 
                                stride = 2, weights_initializer=tf.contrib.layers.xavier_initializer()):
                with slim.arg_scope([slim.batch_norm], decay=0.95, center=True, scale=True,
                                    updates_collections = None, activation_fn=tf.nn.relu,
                                    is_training =(self.mode =='train' or self.mode=='pretrain')):
                    net = slim.conv2d(input_array, 6, [3,3], scope='e_conv1') # (batch_size, 16, 16, 2)
                    net = slim.batch_norm(net, scope='e_bn1')
                    net = slim.conv2d(net, 12, [3,3], scope='e_conv2') # (batch_size, 8, 8, 12)
                    net = slim.batch_norm(net,activation_fn=tf.nn.tanh, scope='e_bn2')
                    if self.mode == 'pretrain':
                        net = slim.conv2d(net, 2, [1,1],  scope='out')
                        net = slim.flatten(net)
                    return net

    def Decoder(self, z, reuse=False):
        # z : (batch_size, 8, 8, 12)
        with tf.variable_scope('Decoder', reuse=reuse):
            with slim.arg_scope([slim.conv2d_transpose], padding='SAME', activation_fn=None,
                               stride = 2, weights_initializer=tf.contrib.layers.xavier_initializer()):
                with slim.arg_scope([slim.batch_norm], decay = 0.95, center=True, scale=True,
                                    updates_collections = None, activation_fn=tf.nn.relu,
                                    is_training = (self.mode == 'train')):
                    net = slim.conv2d_transpose(z, 6, [3,3], scope='d_conv1') # (batch_size, 16, 16, 6)
                    net = slim.batch_norm(net,scope='d_bn1')
                    net = slim.conv2d_transpose(net, 3, [3,3], scope='d_conv2') # (batch_size, 32, 32, 3)
                    net = slim.batch_norm(net,activation_fn = tf.nn.tanh,scope='d_bn9')
                    return net
                    
    def build_model(self):
        self.embedding =self.Encoder(self.input_array)
        self.recon = self.Decoder(self.embedding)
        return self.recon

In [3]:
class Solver(object):
    def __init__(self, model, batch_size = 128, 
                 train_iter = 10000,
                 log_dir = 'logs', 
                 model_save_path='model',
                 test_model='model/train-9500'):
        self.model = model
        self.batch_size = batch_size
        self.train_iter = train_iter
        self.log_dir = log_dir
        self.model_save_path = model_save_path
        self.test_model = test_model
        self.config = tf.ConfigProto(log_device_placement=True)
        self.config.gpu_options.allow_growth=True
        self.data = np.load('image_32.npy').item()
    
    def batch_generator(self, shuffle = True):
        id_ = np.array(list(self.data.keys()))
        
        data_size = len(id_)
        num_batches_per_epoch = int(data_size/ self.batch_size)+1
        
        for epoch in range(num_batches_per_epoch):
            if shuffle:
                shuffle_indices = np.random.permutation(np.arange(data_size))
                id_shuffle = id_[shuffle_indices]
            else:
                id_shuffle = id_[shuffle_indices]
                
            for batch_num in range(num_batches_per_epoch):
                start_index = batch_num * batch_size
                end_index = (batch_num +1) * batch_size
                
                if end_index < data_size:
                    tmp_list = []
                    id_batch = id_shuffle[start_index:end_index]
                    for i in id_batch:
                        tmp_list.append(self.data[i])
                    yield tmp_list

    def train(self):
        with tf.Graph().as_default():
            sess = tf.Session(config = self.config)
            with sess.as_default():
                scae = self.model
#                 scae.build_model()
        
                if tf.gfile.Exists(self.log_dir):
                    tf.gfile.DeleteRecursively(self.log_dir)
                tf.gfile.MakeDirs(self.log_dir)
                
                global_step = tf.Variable(0, name = 'global_step', trainable = False)
                opt = tf.train.AdamOptimizer(scae.learning_rate)
                optimizer = opt.minimize(scae.loss, global_step = global_step)
                
                # omit varibale for tensorborad (about summary)
                
                saver = tf.train.Saver(tf.global_variables(), max_to_keep = 4)
                sess.run(tf.global_variables_initializer())
                
                batch_train = self.batch_generator()
                
                for data in batch_train:
                    feed_dict = {scae.input_array : data}
                    current_step = sess.run(global_step , feed_dict = feed_dict)
                    optimizer.run(feed_dict = feed_dict)
                    if current_step % 100 == 0:
                        print("step: {}".format(current_step))
                        print("==validation start==")
                        batch_val = self.batch_generator()

                        losses = []
                        for val_data in batch_val:
                            feed_dict = {scae.input_array : val_data}
                            l = scae.loss.eval(feed_dict = feed_dict)
                            losses.append(l)
                        print("Mean loss  = " + str(sum(losses)/ len(losses)))
                        saver.save(sess, save_path = self.model_save_path, global_step = current_step)
                        print("==training==")
                        
                        
                if(step+1) %500 == 0:
                    loss = sess.run(model.loss, feed_dict)
                    test_loss = sess.run(model.loss, feed_dict={model.spectrogram: test_spec})
                    print("Step: [%d/%d] loss: [%.3f] test loss: [%.3f]"%(step+1, self.train_iter, loss, test_loss))
                    saver.save(sess, os.path.join(self.model_save_path, 'train'), global_step = step+1)
                    print('train-%d saved..!'%(step+1))
                    

#     def test(self):
#         model = self.model
#         model.build_model()
        
#         with tf.Session(config = self.config) as sess:
#             print("loading test model")
#             saver = tf.train.Saver()
#             saver.restore(sess, self.test_model)

#             for i in range(5):
#                 batch_spec, _ = self.batch_fft(10)
#                 feed_dict = {model.test_spectrogram : batch_spec}
#                 sample_batch_recon = sess.run(model.test_recon, feed_dict)
#                 fig, axs = plt.subplots(2, 5, figsize=(15, 4))
#                 for example_i in range(5):
#                     axs[0][example_i].matshow(np.reshape(batch_spec[example_i, :], (512, 512))) #, cmap=plt.get_cmap('gray'))
#                     axs[1][example_i].matshow(np.reshape(sample_batch_recon[example_i, :], (512,512))) #, cmap=plt.get_cmap('gray'))
#                 plt.show()

In [4]:
model = SCAE('train', learning_rate=0.005, batch_size = 128)
solver = Solver(model)
    
solver.train()

ValueError: No variables to optimize.

In [None]:
model = SCAE('test', learning_rate=0.0003, batch_size = 16)
solver = Solver(model)
    
solver.test()