# CIFAR10 Conv Autoencoder

In [None]:
import os
import tensorflow as tf
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
import warnings
import time
import cPickle
warnings.filterwarnings("ignore", category=DeprecationWarning)  # silence DeprecationWarning messages
# Report the TensorFlow version in use; the code below relies on the 1.x graph API.
print("Current version [%s]" %(tf.__version__))
print("Packages Loaded")

### Configurations

In [None]:
# All run-time settings are registered as TensorFlow flags and read back
# through the module-level FLAGS object defined at the end of this cell.
_flags = tf.app.flags

# --- Dataset ---------------------------------------------------------------
_flags.DEFINE_integer("img_size", 32, "Image size of CIFAR-10 dataset")
_flags.DEFINE_integer("img_num", 10000, "Number of images in one cifar batch")
_flags.DEFINE_integer("batch_num", 5, "Number of cifar batches in dataset")
_flags.DEFINE_string("train_dir", "./../../../Dataset/cifar-10-batches-py",
                     "Directory which contains the train data")
_flags.DEFINE_string("test_dir", "./../../../Dataset/cifar-10-batches-py",
                     "Directory which contains the test data")

# --- Network ---------------------------------------------------------------
_flags.DEFINE_integer("batch_size", 100, "Number of images to process in a batch")
# l1_ratio scales every layer's channel count in the layer builders below.
_flags.DEFINE_float("l1_ratio", 0.5, "Ratio of level1")
_flags.DEFINE_float("l2_ratio", 0.5, "Ratio of level2")

# --- Optimization ----------------------------------------------------------
_flags.DEFINE_float("lr", 0.001, "Learning rate")

# --- Training --------------------------------------------------------------
_flags.DEFINE_integer("training_epochs", 2000, "Number of epochs to run")
_flags.DEFINE_integer("display_step", 10, "Number of iterations to display training output")
_flags.DEFINE_integer("save_step", 20, "Number of interations to save checkpoint")
_flags.DEFINE_integer("save_max", 10, "Number of checkpoints to remain")

# --- Output locations ------------------------------------------------------
_flags.DEFINE_string("nets", "./nets", "Directory where to write the checkpoints")
_flags.DEFINE_string("outputs", "./outputs", "Directory where to save the output images")
_flags.DEFINE_string("tboard", "./tensorboard", "Directory where to save the tensorboard logs")

FLAGS = _flags.FLAGS
print("FLAGS READY")

### GPU control

In [None]:
# Restrict this process to the GPU with index 0.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Session configuration shared by the sessions created later in this file.
config = tf.ConfigProto(allow_soft_placement=True)
gpu_opts = config.gpu_options
gpu_opts.allow_growth = True
gpu_opts.per_process_gpu_memory_fraction = 0.3

### Load Data

In [None]:
def unpickle(file):
    """Load and return the object stored in a pickle file.

    Args:
        file: Path of the pickle file to read.

    Returns:
        Whatever object the file contains (for a CIFAR-10 batch file this is
        expected to be a dict -- see ``read_cifar``).

    SECURITY: unpickling executes arbitrary code embedded in the file. Only
    call this on dataset files obtained from a trusted source.
    """
    # Use the fast C pickler on Python 2 and the standard module on Python 3.
    # On Python 3, 'latin1' lets pickles written by Python 2 (including NumPy
    # arrays) load with ordinary ``str`` keys.
    try:
        import cPickle as pickle_mod
        load_kwargs = {}
    except ImportError:
        import pickle as pickle_mod
        load_kwargs = {'encoding': 'latin1'}

    # No TensorFlow ops are created here, so no device scope is needed.
    with open(file, 'rb') as fo:
        return pickle_mod.load(fo, **load_kwargs)

def read_cifar(file):
    """Read one CIFAR-10 python batch file and return its images.

    Args:
        file: Path of a CIFAR-10 batch file readable by ``unpickle``; the
            loaded object must provide a ``'data'`` entry.

    Returns:
        Float array of shape (N, 32, 32, 3): the ``'data'`` entry divided by
        255 and rearranged from (N, 3, 32, 32) to channels-last order.
        For a standard batch file N is 10000.
    """
    batch = unpickle(file)
    # Dividing by a float promotes the pixel data to floating point.
    # Assumes 'data' holds values in 0..255 -- TODO confirm for other datasets.
    pixels = batch['data'] / 255.

    # Split each flat row into 3 planes of 32x32, then move channels last.
    planes = np.reshape(pixels, (-1, 3, 32, 32))
    return np.transpose(planes, (0, 2, 3, 1))
   

### Generating random noise mask

In [None]:
def noise_mask(prob=0.5):
    """Create one random occlusion mask of shape (img_size, img_size, 3).

    A uniform draw in [0, 1) decides whether the mask contains noise: if the
    draw exceeds ``prob`` a random rectangle is filled with uniform noise,
    otherwise the mask stays all zeros. ``prob`` is therefore the chance of
    returning an EMPTY mask.

    The rectangle is found by rejection sampling two corner points until both
    coordinate differences lie in [5, 14]; since both corners are included,
    each side of the rectangle is 6 to 15 pixels long.

    Args:
        prob: Threshold for the uniform draw (default 0.5).

    Returns:
        Float array that is zero outside the rectangle and holds samples from
        [0, 1) inside it.
    """
    mask = np.zeros([FLAGS.img_size, FLAGS.img_size, 3])
    if np.random.random() > prob:
        # Start outside the accepted range so the loop runs at least once.
        uthd = FLAGS.img_size
        lthd = 0
        # NOTE(review): never terminates if img_size < 6, because no pair of
        # corners can then differ by 5 on both axes.
        while uthd > 14 or lthd < 5:
            # randint's upper bound is exclusive, so this draws 0..img_size-1,
            # matching the deprecated random_integers(0, img_size-1) call.
            ver1 = np.random.randint(0, FLAGS.img_size, size=2)
            ver2 = np.random.randint(0, FLAGS.img_size, size=2)
            span = np.abs(ver1 - ver2)
            uthd = span.max()    # longer side (as a coordinate difference)
            lthd = span.min()    # shorter side (as a coordinate difference)

        # Bounds along the first and second array axes, both ends inclusive.
        lo0, lo1 = np.minimum(ver1, ver2)
        hi0, hi1 = np.maximum(ver1, ver2)
        noise = np.random.random((hi0 - lo0 + 1, hi1 - lo1 + 1, 3))
        mask[lo0:hi0 + 1, lo1:hi1 + 1, :] = noise
    return mask

def noise_batch(batch_num):
    """Return ``batch_num`` independent noise masks stacked on axis 0.

    Args:
        batch_num: Number of masks to generate.

    Returns:
        Array of shape (batch_num, img_size, img_size, 3); each entry is the
        result of one ``noise_mask()`` call with its default threshold.
    """
    side = FLAGS.img_size
    masks = np.zeros([batch_num, side, side, 3])
    # Iterating yields writable views, so each mask is filled in place.
    for slot in masks:
        slot[...] = noise_mask()
    return masks


def occl(target, disturb):
    """Overlay ``disturb`` onto ``target`` to build an occluded image.

    Wherever ``disturb`` is exactly zero the ``target`` value is kept;
    everywhere else the ``target`` value is zeroed out and replaced by the
    ``disturb`` value.

    Args:
        target: Array of clean values.
        disturb: Array of the same shape holding the occluding values.

    Returns:
        The corrupted array, as a float array.
    """
    # 1.0 where the target survives, 0.0 where the disturbance covers it.
    keep = (disturb == 0).astype(float)
    return target * keep + disturb

### Nested Convolutional Encoder

In [None]:
def _nested_enc(l1, out_channel, name, filter_size=3, std=[1,2,2,1],pad='SAME', stddev=0.1):
    """Build one convolution + sigmoid encoder layer.

    Variables are created (or looked up) under the scope ``level1/<name>``.

    Args:
        l1: Input 4-D tensor; its last dimension is taken as the number of
            input channels.
        out_channel: Nominal output width; the layer actually creates
            ``int(out_channel * FLAGS.l1_ratio)`` filters.
        name: Variable-scope name of this layer.
        filter_size: Side length of the square kernel.
        std: Strides passed to ``tf.nn.conv2d`` (the default list is never
            mutated here).
        pad: Padding mode passed to ``tf.nn.conv2d``.
        stddev: Standard deviation of the normal initializer used for both
            weights and biases.

    Returns:
        The sigmoid-activated convolution output.
    """
    in_channel = l1.get_shape()[3]
    # out_channel * ratio is a float; variable shapes need an integer, so cast
    # once here (truncating, as the decoders do for their output shapes).
    n_out = int(out_channel*FLAGS.l1_ratio)

    with tf.variable_scope('level1'):
        with tf.variable_scope(name):
            l1_weights = tf.get_variable('weights',
                                         [filter_size, filter_size, in_channel, n_out],
                                         tf.float32,
                                         initializer=tf.random_normal_initializer(stddev=stddev))
            l1_biases = tf.get_variable('biases',
                                        [n_out],
                                        tf.float32,
                                        initializer=tf.random_normal_initializer(stddev=stddev))

    l1_conv = tf.nn.conv2d(l1, l1_weights, strides=std, padding=pad)
    l1_act = tf.nn.sigmoid(tf.add(l1_conv, l1_biases))

    return l1_act

def _nested_enc_init(_img, out_channel, name, filter_size=3, std=[1,2,2,1],pad='SAME', stddev=0.1):
    """Build the first encoder layer, applied directly to the input image.

    Identical in structure to a regular encoder layer except that the number
    of input channels is fixed at 3 instead of being read from the input.
    Variables live under the scope ``level1/<name>``.

    Args:
        _img: Input image tensor with 3 channels in the last dimension.
        out_channel: Nominal output width; the layer actually creates
            ``int(out_channel * FLAGS.l1_ratio)`` filters.
        name: Variable-scope name of this layer.
        filter_size: Side length of the square kernel.
        std: Strides passed to ``tf.nn.conv2d`` (the default list is never
            mutated here).
        pad: Padding mode passed to ``tf.nn.conv2d``.
        stddev: Standard deviation of the normal initializer used for both
            weights and biases.

    Returns:
        The sigmoid-activated convolution output.
    """
    # out_channel * ratio is a float; variable shapes need an integer.
    n_out = int(out_channel*FLAGS.l1_ratio)

    with tf.variable_scope('level1'):
        with tf.variable_scope(name):
            l1_weights = tf.get_variable('weights',
                                         [filter_size, filter_size, 3, n_out],
                                         tf.float32,
                                         initializer=tf.random_normal_initializer(stddev=stddev))
            l1_biases = tf.get_variable('biases',
                                        [n_out],
                                        tf.float32,
                                        initializer=tf.random_normal_initializer(stddev=stddev))

    l1_conv = tf.nn.conv2d(_img, l1_weights, strides=std, padding=pad)
    l1_act = tf.nn.sigmoid(tf.add(l1_conv, l1_biases))

    return l1_act

def _nested_enc_last(l1, out_channel, name, pad='VALID', stddev=0.1):
    """Build the final encoder layer, which collapses the feature map.

    The kernel is as large as the whole input feature map and the stride is 1,
    so with the default 'VALID' padding the output has spatial size 1x1 --
    a fully-convolutional layer producing one code vector per image.
    Variables live under the scope ``level1/<name>``.

    Args:
        l1: Input 4-D tensor with statically known height, width and channels.
        out_channel: Nominal code width; the layer actually creates
            ``int(out_channel * FLAGS.l1_ratio)`` filters.
        name: Variable-scope name of this layer.
        pad: Padding mode passed to ``tf.nn.conv2d``.
        stddev: Standard deviation of the normal initializer used for both
            weights and biases.

    Returns:
        The sigmoid-activated convolution output.
    """
    l1_shape = l1.get_shape()
    # out_channel * ratio is a float; variable shapes need an integer.
    n_out = int(out_channel*FLAGS.l1_ratio)

    with tf.variable_scope('level1'):
        with tf.variable_scope(name):
            # Kernel height/width equal the input's height/width.
            l1_weights = tf.get_variable('weights',
                                         [l1_shape[1], l1_shape[2], l1_shape[3], n_out],
                                         tf.float32,
                                         initializer=tf.random_normal_initializer(stddev=stddev))
            l1_biases = tf.get_variable('biases',
                                        [n_out],
                                        tf.float32,
                                        initializer=tf.random_normal_initializer(stddev=stddev))

    l1_conv = tf.nn.conv2d(l1, l1_weights, strides=[1,1,1,1], padding=pad)
    l1_act = tf.nn.sigmoid(tf.add(l1_conv, l1_biases))

    return l1_act

### Nested Convolutional Decoder

In [None]:
def _nested_dec(l1,out_size, out_channel, name, filter_size=3, std=[1,2,2,1], pad='SAME', stddev=0.1):
    """Build one transposed-convolution + sigmoid decoder layer.

    Variables are created (or looked up) under the scope ``level1/<name>``.

    Args:
        l1: Input 4-D tensor; its last dimension is taken as the number of
            input channels.
        out_size: Two-element sequence giving the output height and width.
        out_channel: Nominal output width; the layer actually produces
            ``int(out_channel * FLAGS.l1_ratio)`` channels.
        name: Variable-scope name of this layer.
        filter_size: Side length of the square kernel.
        std: Strides passed to ``tf.nn.conv2d_transpose`` (the default list
            is never mutated here).
        pad: Padding mode passed to ``tf.nn.conv2d_transpose``.
        stddev: Standard deviation of the normal initializer used for both
            weights and biases.

    Returns:
        The sigmoid-activated transposed-convolution output.
    """
    in_channel = l1.get_shape()[3]
    # Compute the integer channel count once so the variable shapes and the
    # requested output shape always agree.
    n_out = int(out_channel*FLAGS.l1_ratio)

    # The batch dimension is taken dynamically from the input tensor.
    l1_out_shape = [tf.shape(l1)[0], out_size[0], out_size[1], n_out]

    with tf.variable_scope('level1'):
        with tf.variable_scope(name):
            # conv2d_transpose kernels are laid out [h, w, out_ch, in_ch].
            l1_weights = tf.get_variable('weights',
                                         [filter_size, filter_size, n_out, in_channel],
                                         tf.float32,
                                         initializer=tf.random_normal_initializer(stddev=stddev))
            l1_biases = tf.get_variable('biases',
                                        [n_out],
                                        tf.float32,
                                        initializer=tf.random_normal_initializer(stddev=stddev))

    l1_dec = tf.nn.conv2d_transpose(l1, l1_weights,
                                    output_shape=l1_out_shape, strides=std, padding=pad)
    l1_act = tf.nn.sigmoid(tf.add(l1_dec, l1_biases))

    return l1_act
    
def _nested_dec_init(_l1, out_size, out_channel, name, pad='VALID', stddev=0.1):
    """Build the first decoder layer, which expands the encoded tensor.

    The kernel has the same height and width as the requested output and the
    stride is 1; this mirrors the final encoder layer, whose kernel covers
    its whole input. Variables live under the scope ``level1/<name>``.

    Args:
        _l1: Encoded input 4-D tensor; its last dimension is taken as the
            number of input channels.
        out_size: Two-element sequence giving the output height and width;
            also used as the kernel height and width.
        out_channel: Nominal output width; the layer actually produces
            ``int(out_channel * FLAGS.l1_ratio)`` channels.
        name: Variable-scope name of this layer.
        pad: Padding mode passed to ``tf.nn.conv2d_transpose``.
        stddev: Standard deviation of the normal initializer used for both
            weights and biases.

    Returns:
        The sigmoid-activated transposed-convolution output.
    """
    in_channel = _l1.get_shape()[3]
    # Compute the integer channel count once so the variable shapes and the
    # requested output shape always agree.
    n_out = int(out_channel*FLAGS.l1_ratio)

    # The batch dimension is taken dynamically from the input tensor.
    l1_out_shape = [tf.shape(_l1)[0], out_size[0], out_size[1], n_out]

    with tf.variable_scope('level1'):
        with tf.variable_scope(name):
            # conv2d_transpose kernels are laid out [h, w, out_ch, in_ch].
            l1_weights = tf.get_variable('weights',
                                         [out_size[0], out_size[1], n_out, in_channel],
                                         tf.float32,
                                         initializer=tf.random_normal_initializer(stddev=stddev))
            l1_biases = tf.get_variable('biases',
                                        [n_out],
                                        tf.float32,
                                        initializer=tf.random_normal_initializer(stddev=stddev))

    l1_dec = tf.nn.conv2d_transpose(_l1, l1_weights,
                                    output_shape=l1_out_shape, strides=[1,1,1,1], padding=pad)
    l1_act = tf.nn.sigmoid(tf.add(l1_dec, l1_biases))
    return l1_act
 
    
def _nested_dec_last(l1, l2_s=None, l2=None, name=None, filter_size=3, std=[1,2,2,1], pad='SAME', stddev=0.1):
    """Build the output layer that reconstructs the full-size RGB image.

    Variables are created (or looked up) under the scope ``level1/<name>``.

    Args:
        l1: Input 4-D tensor; its last dimension is taken as the number of
            input channels.
        l2_s: Unused. Kept, now optional, so existing positional callers
            keep working.
        l2: Unused. Kept, now optional, for the same reason.
        name: Variable-scope name of this layer. Required, even though it has
            a default for syntactic reasons.
        filter_size: Side length of the square kernel.
        std: Strides passed to ``tf.nn.conv2d_transpose`` (the default list
            is never mutated here).
        pad: Padding mode passed to ``tf.nn.conv2d_transpose``.
        stddev: Standard deviation of the normal initializer used for both
            weights and biases.

    Returns:
        Sigmoid-activated tensor of shape
        (batch, FLAGS.img_size, FLAGS.img_size, 3).

    Raises:
        TypeError: If ``name`` is not supplied.
    """
    if name is None:
        # Mirror the error a missing required argument would produce.
        raise TypeError("_nested_dec_last() requires the 'name' argument")

    in_channel = l1.get_shape()[3]

    # The batch dimension is taken dynamically from the input tensor.
    _out_shape = [tf.shape(l1)[0], FLAGS.img_size, FLAGS.img_size, 3]

    with tf.variable_scope('level1'):
        with tf.variable_scope(name):
            # conv2d_transpose kernels are laid out [h, w, out_ch, in_ch].
            l1_weights = tf.get_variable('weights',
                                         [filter_size, filter_size, 3, in_channel],
                                         tf.float32,
                                         initializer=tf.random_normal_initializer(stddev=stddev))
            out_biases = tf.get_variable('biases',
                                         [3],
                                         tf.float32,
                                         initializer=tf.random_normal_initializer(stddev=stddev))

    l1_dec = tf.nn.conv2d_transpose(l1, l1_weights,
                                    output_shape=_out_shape, strides=std, padding=pad)
    l1_act = tf.nn.sigmoid(tf.add(l1_dec, out_biases))
    return l1_act

### Graph setup

In [None]:
# Network topology: nominal channel widths of the encoder and decoder layers.
# Each layer builder multiplies its width by FLAGS.l1_ratio.
n_enc1, n_enc2, n_enc3 = 64, 128, 256
n_dec1, n_dec2 = 128, 64

# Graph inputs: batches of img_size x img_size RGB images (batch size left open).
_image_shape = [None, FLAGS.img_size, FLAGS.img_size, 3]
ph_pure = tf.placeholder("float", _image_shape)    # clean image (reconstruction target)
ph_crpt = tf.placeholder("float", _image_shape)    # corrupted image (network input)


# Model
def nested_ae_conv(_X):
    """Assemble the convolutional autoencoder and return its reconstruction.

    Three encoder layers reduce the image to a 1x1 code, then three decoder
    layers expand it back to image size. The spatial sizes in the comments
    assume 32x32 inputs, which the hard-coded [8,8] and [16,16] decoder
    sizes also require.

    Args:
        _X: Input image batch tensor with 3 channels in the last dimension.

    Returns:
        Reconstructed image batch from the final decoder layer.
    """
    l1_enc1 = _nested_enc_init(_X, n_enc1, name='enc1')    # 32 -> 16 (stride 2)
    l1_enc2 = _nested_enc(l1_enc1, n_enc2, name='enc2')    # 16 -> 8  (stride 2)
    l1_enc3 = _nested_enc_last(l1_enc2, n_enc3, name='enc3')    # 8 -> 1  (full-size kernel, VALID)
    l1_dec1 = _nested_dec_init(l1_enc3, [8,8], n_dec1, name='dec1')    # 1 -> 8
    l1_dec2 = _nested_dec(l1_dec1, [16,16], n_dec2, name='dec2')    # 8 -> 16
    # The 2nd and 3rd parameters of _nested_dec_last are not used by it;
    # None is passed explicitly so the call matches its positional signature.
    l1_out = _nested_dec_last(l1_dec2, None, None, name='out')    # 16 -> 32
    return l1_out

# Generation
# Reconstruction of the corrupted input; the model's last layer requests an
# output of shape [batch, img_size, img_size, 3].
core_gen = nested_ae_conv(ph_crpt)   # reconstructed image batch

# Loss & Optimizer
with tf.name_scope("loss") as scope:
    # Squared-error reconstruction loss between the output and the clean image.
    # NOTE(review): per the TensorFlow docs tf.nn.l2_loss already returns a
    # scalar (sum(t**2)/2), so reduce_mean would be a no-op and the loss would
    # grow with batch size -- confirm this is intended.
    loss = tf.reduce_mean(tf.nn.l2_loss(core_gen-ph_pure))
    # Two summary ops over the same tensor, so values computed from training
    # feeds and from test feeds are logged under separate tags.
    _train_loss = tf.summary.scalar("train_loss", loss)
    _test_loss = tf.summary.scalar("test_loss", loss)

# One Adam update step on the loss, using the learning rate from FLAGS.lr.
optm = tf.train.AdamOptimizer(learning_rate=FLAGS.lr).minimize(loss)


print("Graphs Ready")

### Initialize

In [None]:
# Combined summary op. NOTE(review): `merged` is not referenced by any other
# code visible in this file; the individual summary ops are run instead.
merged = tf.summary.merge_all()
# Make sure the TensorBoard log directory exists before opening a writer on it.
tensorboard_path = FLAGS.tboard
if not os.path.exists(tensorboard_path):
    os.makedirs(tensorboard_path)
# Event-file writer for the loss summaries (no graph is passed to it here).
writer = tf.summary.FileWriter(tensorboard_path)
# Op that initializes all global variables; run once after a session is created.
init = tf.global_variables_initializer()

print("Initialize Ready")

### Data saving

In [None]:
# Locations for the saved figures and for the checkpoints.
outputdir = FLAGS.outputs
savedir = FLAGS.nets

# Create the train/test figure folders and the checkpoint folder if missing.
for _needed_dir in (outputdir+'/train', outputdir+'/test', savedir):
    if not os.path.exists(_needed_dir):
        os.makedirs(_needed_dir)

# Checkpoint writer that keeps at most FLAGS.save_max checkpoints.
saver = tf.train.Saver(max_to_keep=FLAGS.save_max)
print("Saver ready")

### Run

In [None]:
#################################################
# Parameters
training_epochs = FLAGS.training_epochs
batch_num = FLAGS.batch_num
batch_size = FLAGS.batch_size
n_total_batch = int(FLAGS.img_num/batch_size)    # mini-batches per CIFAR batch file
display_step = FLAGS.display_step
#################################################
# Plot parameters
n_plot = 5    # number of example images per saved figure
# Training examples come from the training directory (data_batch_1 is a
# training file); test examples come from the test directory.
cifar10_train_img = read_cifar(FLAGS.train_dir+'/data_batch_1')     # (10000, 32, 32, 3)
cifar10_test_img = read_cifar(FLAGS.test_dir+'/test_batch')     # (10000, 32, 32, 3)
# The displayed examples are drawn once and stay fixed for the whole run,
# so figures from different epochs are directly comparable.
train_disp_idx = np.random.randint(FLAGS.img_num, size=n_plot)
train_gt_pure = np.copy(np.take(cifar10_train_img, train_disp_idx, axis=0))    # (n_plot, 32, 32, 3)
test_disp_idx = np.random.randint(FLAGS.img_num, size=n_plot)
test_gt_pure = np.copy(np.take(cifar10_test_img, test_disp_idx, axis=0))    # (n_plot, 32, 32, 3)

# Index arrays that the training loop shuffles in place to pick mini-batches.
rand_train_idx = np.arange(FLAGS.img_num)
rand_test_idx = np.arange(FLAGS.img_num)

##################################################
# Initialize
sess = tf.Session(config=config)
sess.run(init)
# cifar10_test_img loaded above is reused for evaluation; no second read is needed.

#################################################
# Optimize
def _save_comparison_figure(corrupted, generated, out_path):
    """Save a two-row figure: corrupted inputs on top, reconstructions below.

    Args:
        corrupted: Sequence of RGB images fed to the network.
        generated: Sequence of the network's reconstructions, same length.
        out_path: Path passed to ``plt.savefig``.
    """
    n_cols = len(corrupted)
    fig, axes = plt.subplots(nrows=2, ncols=n_cols, figsize=(10,n_cols))
    plt.setp(axes, xticks=np.arange(0,31,8), yticks=np.arange(0,31,8))
    for col in range(n_cols):
        # No cmap is given: matplotlib ignores it for RGB image data.
        axes[0, col].imshow(corrupted[col])
        axes[0, col].set(ylabel='gt_crpt')
        axes[0, col].label_outer()

        axes[1, col].imshow(generated[col])
        axes[1, col].set(ylabel='gen_pure')
        axes[1, col].label_outer()

    plt.savefig(out_path)
    plt.close(fig)    # release the figure so figures do not pile up over epochs


start_optm = time.time()
for epoch in range(training_epochs):
    # One pass over every CIFAR batch file makes up an epoch.
    for cifar_batch_idx in range(FLAGS.batch_num):
        start_epoch = time.time()
        cifar_batch_name = FLAGS.train_dir+'/data_batch_%d' %(cifar_batch_idx+1)
        cifar10_img = read_cifar(cifar_batch_name)     # (10000, 32, 32, 3)

        # NOTE(review): reseeding inside the per-file loop makes every batch
        # file within an epoch draw the same noise sequence -- confirm intended.
        np.random.seed(epoch)
        np.random.shuffle(rand_train_idx)
        np.random.shuffle(rand_test_idx)

        ##################################################
        # Iteration
        for batch_idx in range(n_total_batch):
            batch_sel = rand_train_idx[batch_size*batch_idx:batch_size*(batch_idx+1)]
            batch_pure = np.take(cifar10_img, batch_sel, axis=0)   # pure image
            noise = noise_batch(batch_size)    # random noise
            batch_crpt = occl(batch_pure, noise)   # corrupted image
            train_feeds = {ph_pure: batch_pure, ph_crpt: batch_crpt}
            sess.run(optm, feed_dict=train_feeds)

        # Training loss is reported on the last mini-batch of this file.
        train_loss, tb_train_loss = sess.run([loss,_train_loss], feed_dict=train_feeds)

        # Test loss is measured on one freshly corrupted batch of test images.
        test_pure = np.take(cifar10_test_img,rand_test_idx[:batch_size], axis=0)    # pure image
        test_noise = noise_batch(batch_size)    # random noise
        test_crpt = occl(test_pure,test_noise)   # corrupted image
        test_feeds = {ph_pure: test_pure, ph_crpt: test_crpt}
        test_loss, tb_test_loss = sess.run([loss,_test_loss], feed_dict=test_feeds)

        writer.add_summary(tb_train_loss, epoch)
        writer.add_summary(tb_test_loss, epoch)

        epoch_time = time.time() - start_epoch
        current_time = time.time() - start_optm
        print("Epoch : %03d/%03d data_batch_%d,  Train_loss : %.4f  Test_loss : %.4f, Time/batch_file : %.4f, Training time: %.4f"
              % (epoch+1, training_epochs, cifar_batch_idx+1, train_loss, test_loss, epoch_time, current_time))

    # Display
    if (epoch+1) % display_step == 0:
        # Corrupt the fixed example images with fresh noise.
        train_gt_noise = noise_batch(n_plot)    # random noise
        train_gt_crpt = occl(train_gt_pure,train_gt_noise)   # corrupted image
        train_gt_feeds = {ph_pure: train_gt_pure,ph_crpt: train_gt_crpt}

        test_gt_noise = noise_batch(n_plot)    # random noise
        test_gt_crpt = occl(test_gt_pure,test_gt_noise)   # corrupted image
        test_gt_feeds = {ph_pure: test_gt_pure, ph_crpt: test_gt_crpt}

        ##########################################################
        # generated images
        train_gen_pure = sess.run(core_gen, feed_dict=train_gt_feeds)  # (n_plot, 32, 32, 3)
        test_gen_pure = sess.run(core_gen,feed_dict=test_gt_feeds)  # (n_plot, 32, 32, 3)

        ##########################################################
        # one figure for the training examples, one for the test examples
        _save_comparison_figure(train_gt_crpt, train_gen_pure,
                                outputdir+'/train/epoch %03d' %(epoch+1))
        _save_comparison_figure(test_gt_crpt, test_gen_pure,
                                outputdir+'/test/epoch %03d' %(epoch+1))

    # Save
    # Checked independently of the display branch, so checkpoints are written
    # every save_step epochs even when save_step is not a multiple of display_step.
    if (epoch+1) % FLAGS.save_step ==0:
        savename = savedir+"/net-"+str(epoch+1)+".ckpt"
        saver.save(sess, savename)
        print("[%s] SAVED" % (savename))

print("Optimization Finished")

### Restore

In [None]:
# Set do_restore to 1 to load the checkpoint written at the final epoch.
do_restore = 1
if do_restore == 1:
    # Reuse the shared session configuration so the GPU memory settings also
    # apply to this session, as they do for the training session.
    sess = tf.Session(config=config)
    # The checkpoint name encodes the epoch number. A checkpoint for the final
    # epoch exists only if training saved at that epoch.
    epoch = FLAGS.training_epochs
    savename = savedir+"/net-"+str(epoch)+".ckpt"
    saver.restore(sess, savename)
    print ("NETWORK RESTORED")
else:
    print ("DO NOTHING")

### Test

In [None]:
# Pick n_plot random test images and corrupt them with fresh noise.
test_disp_idx = np.random.randint(FLAGS.img_num, size=n_plot)
test_gt_pure = np.copy(np.take(cifar10_test_img, test_disp_idx, axis=0))    # (n_plot, 32, 32, 3)
test_gt_noise = noise_batch(n_plot)    # random noise, one mask per displayed image
test_gt_crpt = occl(test_gt_pure,test_gt_noise)   # corrupted image
# Only the corrupted input is needed to run the reconstruction.
test_gt_feeds = {ph_crpt: test_gt_crpt}
# core_gen is the only output tensor of the model; fetch it alone.
test_gen_pure = sess.run(core_gen, feed_dict=test_gt_feeds)

# plotting results from testing data
fig, axes = plt.subplots(nrows=2, ncols=n_plot, figsize=(10,n_plot))   # 2 rows x n_plot images
plt.setp(axes, xticks=np.arange(0,31,8), yticks=np.arange(0,31,8))
for k in range(n_plot):
    # top row: corrupted input
    axes[0, k].imshow(test_gt_crpt[k])
    axes[0, k].set(ylabel='gt_crpt')
    axes[0, k].label_outer()

    # bottom row: network reconstruction
    axes[1, k].imshow(test_gen_pure[k])
    axes[1, k].set(ylabel='gen_pure')
    axes[1, k].label_outer()
