In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

%load_ext autoreload
%aimport util
from util import *
%autoreload 1

from PIL import Image
from IPython.display import display

In [2]:
# Locate the lighted / delighted image directories under data/ and load them.
# Both helpers are not defined in this notebook -- presumably they come from `from util import *`.
delighted_dirs, lighted_dirs = scan_lighted_delighted('data/')
# NOTE(review): later cells index delighted_data by mesh label and iterate lighted_data as
# (label, array) pairs; the concrete container types are decided inside util.load_dataset -- confirm there.
delighted_data, lighted_data = load_dataset(delighted_dirs, lighted_dirs)
# Meshes held out of training: every lighted sample carrying one of these labels goes to the test split.
test_meshes = ['Mesh_000003', 'Mesh_000006']
train_lighted_data = [(label, arr) for label, arr in lighted_data if label not in test_meshes]
test_lighted_data = [(label, arr) for label, arr in lighted_data if label in test_meshes]
# Seed NumPy's global RNG so the in-place shuffle of the training samples is reproducible.
np.random.seed(0)
np.random.shuffle(train_lighted_data)

In [3]:
# Stage-name tokens. NOTE(review): presumably meant as the `stage` argument of the
# get_*_name helpers defined in this cell; none of them is referenced in the other visible cells.
GEN_ENCODER = 'gen_encoder'
GEN_MID = 'gen_mid'
GEN_DECODER = 'gen_decoder'
DISC_CONV = 'disc_conv'
DISC_FC = 'disc_fc'

def get_dimension_name(stage, layer_num):
    """Return the identifier ``dims_<stage>_<layer_num>`` for a layer's dimensions entry."""
    parts = ['dims', stage, str(layer_num)]
    return '_'.join(parts)

def get_weight_name(stage, layer_num):
    """Return the identifier ``w_<stage>_<layer_num>`` for a layer's weight tensor."""
    parts = ['w', stage, str(layer_num)]
    return '_'.join(parts)
    
def get_bias_name(stage, layer_num):
    """Return the identifier ``b_<stage>_<layer_num>`` for a layer's bias tensor."""
    parts = ['b', stage, str(layer_num)]
    return '_'.join(parts)

# Side length of the full-resolution images; get_train_batch uses it as the bound for random crop origins.
FULL_HEIGHT = 1024
# Side length of the square crops produced by get_train_batch for training.
TRAIN_HEIGHT = 32
# The test-time placeholder is built at full resolution.
TEST_HEIGHT = FULL_HEIGHT

# Per-sample placeholder shapes (batch axis excluded): two equal spatial sides and 3 channels.
train_input_dimensions = (TRAIN_HEIGHT, TRAIN_HEIGHT, 3)
test_input_dimensions = (TEST_HEIGHT, TEST_HEIGHT, 3)

In [4]:
def fc(x, output_depth, name, activation=tf.nn.relu):
    """Fully connected layer computing ``activation(matmul(x, W) + b)``.

    Variables ``W`` (Xavier-initialised) and ``b`` (zero-initialised) are fetched with
    ``tf.get_variable`` inside the variable scope ``name``, so whether they are created or
    reused is governed by the enclosing scope's reuse setting.

    Args:
        x: input tensor; the size of its last axis must be statically known.
        output_depth: number of output units.
        name: variable-scope name for this layer.
        activation: callable applied to the affine output (ReLU by default).

    Returns:
        The activated output tensor.
    """
    n_inputs = int(x.get_shape()[-1])
    with tf.variable_scope(name):
        weights = tf.get_variable(
            'W',
            shape=[n_inputs, output_depth],
            initializer=tf.contrib.layers.xavier_initializer())
        bias = tf.get_variable('b', [output_depth], initializer=tf.zeros_initializer())
        affine = tf.matmul(x, weights) + bias
        return activation(affine)
    
def conv2d(x, output_depth, name, filter_size=3, stride=1, padding='SAME'):
    """2-D convolution followed by a bias add and ReLU.

    Variables ``W`` (Xavier-initialised square kernel) and ``b`` (zeros) are fetched with
    ``tf.get_variable`` inside the variable scope ``name``.

    Args:
        x: 4-D input tensor whose last axis (channel count) is statically known.
        output_depth: number of output channels.
        name: variable-scope name for this layer.
        filter_size: side length of the square kernel.
        stride: stride applied to both spatial axes.
        padding: padding mode forwarded to ``tf.nn.conv2d``.

    Returns:
        The ReLU-activated feature map.
    """
    in_channels = int(x.get_shape()[-1])
    kernel_shape = [filter_size, filter_size, in_channels, output_depth]
    strides = [1, stride, stride, 1]

    with tf.variable_scope(name):
        kernel = tf.get_variable('W',
                                 shape=kernel_shape,
                                 initializer=tf.contrib.layers.xavier_initializer())
        bias = tf.get_variable('b', shape=[output_depth], initializer=tf.zeros_initializer())
        pre_activation = tf.nn.conv2d(x, kernel, strides=strides, padding=padding) + bias
        return tf.nn.relu(pre_activation)

def deconv2d(x, output_depth, name, filter_size=3, stride=1, padding='SAME', scale=2):
    """Convolve ``x`` and then upsample the result with bilinear interpolation.

    Despite the name this is not a transposed convolution: it applies ``conv2d`` (conv + bias +
    ReLU) and resizes the output to ``scale`` times the spatial size of the *input* ``x``.

    Args:
        x: 4-D input tensor; its height and width must be statically known.
        output_depth: number of output channels of the convolution.
        name: variable-scope name; the convolution's variables live under it.
        filter_size: side length of the square kernel.
        stride: convolution stride on both spatial axes.
        padding: padding mode forwarded to the convolution.
        scale: integer upsampling factor applied to the input's height and width (default 2,
            the value that was previously hard-coded).

    Returns:
        Tensor with spatial size ``(height * scale, width * scale)`` and ``output_depth`` channels.
    """
    # Capture the spatial size before the convolution: the target size is relative to the input.
    _, old_height, old_width, _ = x.get_shape().as_list()
    x = conv2d(x, output_depth, name, filter_size=filter_size, stride=stride, padding=padding)
    with tf.variable_scope(name):
        # bilinear interpolation upsampling
        new_height = old_height * scale
        new_width = old_width * scale
        return tf.image.resize_images(x, [new_height, new_width], method=tf.image.ResizeMethod.BILINEAR)

def gradient_difference_loss(expected, predicted, alpha):
    """Mean of ``|dx_diff| ** alpha + |dy_diff| ** alpha`` between two 3-channel image batches.

    For each image the absolute horizontal and vertical finite differences are computed per
    channel; the loss penalises the absolute difference between the expected and predicted
    gradient magnitudes.

    Args:
        expected: 4-D float tensor with 3 channels on the last axis.
        predicted: tensor of the same shape as ``expected``.
        alpha: exponent applied to each gradient difference.

    Returns:
        Scalar tensor holding the mean loss.
    """
    eye = tf.constant(np.identity(3), dtype=tf.float32)
    # Kernel of shape [1, 2, 3, 3]: the [-1, 1] difference along the width axis, per channel.
    kernel_dx = tf.expand_dims(tf.stack([-eye, eye]), 0)
    # Kernel of shape [2, 1, 3, 3]: the [[1], [-1]] difference along the height axis, per channel.
    kernel_dy = tf.stack([tf.expand_dims(eye, 0), tf.expand_dims(-eye, 0)])
    unit_strides = [1, 1, 1, 1]

    def abs_gradient(images, kernel):
        # Magnitude of the finite difference defined by `kernel`.
        return tf.abs(tf.nn.conv2d(images, kernel, unit_strides, padding='SAME'))

    predicted_dx = abs_gradient(predicted, kernel_dx)
    predicted_dy = abs_gradient(predicted, kernel_dy)
    expected_dx = abs_gradient(expected, kernel_dx)
    expected_dy = abs_gradient(expected, kernel_dy)

    diff_x = tf.abs(expected_dx - predicted_dx)
    diff_y = tf.abs(expected_dy - predicted_dy)

    return tf.reduce_mean(diff_x ** alpha + diff_y ** alpha)

In [5]:
class GeneratorNetwork(object):
    """Convolutional encoder/decoder generator with skip connections.

    The layer stack is instantiated on a training-sized placeholder (used for the loss) and on a
    test-sized placeholder (used by ``predict``); both share the variables under the
    ``generator`` variable scope. ``init_network`` must be called before ``fit_batch`` or
    ``predict``.
    """

    # Defaults for every hyperparameter read by init_network; caller-supplied values override
    # them key by key, so a partial dict no longer raises KeyError when the loss is built.
    DEFAULT_HYPERPARAMETERS = {
        'learning_rate' : 1e-3,
        'gdl_alpha' : 1,
        'lambda_l2' : 1,
        'lambda_gdl' : 1,
        'lambda_adv' : 1
    }
    # Added inside tf.log so a discriminator output of exactly 0 cannot yield an infinite loss.
    LOG_EPSILON = 1e-8

    def __init__(self, session, train_input_dimensions, test_input_dimensions,
                 hyperparameters=None,
                 input_residual=False,
                 alpha_mask_loss=False):
        """Store the configuration; no graph nodes are created here.

        Args:
            session: session used by ``fit_batch`` and ``predict``.
            train_input_dimensions: (height, width, depth) of one training sample.
            test_input_dimensions: (height, width, depth) of one test sample.
            hyperparameters: optional dict overriding any of ``DEFAULT_HYPERPARAMETERS``.
            input_residual: if True the network output is added to its input.
            alpha_mask_loss: if True the squared error is multiplied by the fed alpha mask.
        """
        self.sess = session
        self.train_input_dimensions = train_input_dimensions
        self.test_input_dimensions = test_input_dimensions
        # Copy so the caller's dict is never aliased, then overlay the supplied values.
        merged_hyperparameters = dict(self.DEFAULT_HYPERPARAMETERS)
        if hyperparameters:
            merged_hyperparameters.update(hyperparameters)
        self.hyperparams = merged_hyperparameters
        self.train_variables = []
        self.has_defined_layers = False
        self.input_residual = input_residual
        self.alpha_mask_loss = alpha_mask_loss

    def init_network(self, discriminator):
        """Create placeholders, both generator outputs, the combined loss and the optimiser op.

        Args:
            discriminator: object whose ``get_output_tensor(tensor)`` returns the probability
                that ``tensor`` is real; used for the adversarial term.
        """
        train_height, train_width, train_depth = self.train_input_dimensions
        self.train_input = tf.placeholder(tf.float32, shape=[None, train_height, train_width, train_depth])
        self.alpha_mask = tf.placeholder(tf.float32, shape=[None, train_height, train_width, 1])
        self.test_input = tf.placeholder(tf.float32, shape=[None,] + list(self.test_input_dimensions))
        self.expected_output = tf.placeholder(tf.float32, shape=[None,] + list(self.train_input_dimensions))

        # The first call creates the generator variables; the second one reuses them.
        train_output = self.get_output_tensor(self.train_input)
        self.test_output = self.get_output_tensor(self.test_input)

        sq_diff = tf.squared_difference(train_output, self.expected_output)
        if self.alpha_mask_loss:
            # Zero the error wherever the mask is 0 (the single mask channel broadcasts over depth).
            sq_diff *= self.alpha_mask
        l2_loss = tf.reduce_mean(sq_diff)
        # Arguments follow the (expected, predicted) signature; the loss itself is symmetric.
        gdl_loss = gradient_difference_loss(self.expected_output, train_output, self.hyperparams['gdl_alpha'])
        predicted_real = discriminator.get_output_tensor(train_output)
        adv_loss = -tf.reduce_mean(tf.log(predicted_real + self.LOG_EPSILON))
        self.loss = self.hyperparams['lambda_l2'] * l2_loss \
                  + self.hyperparams['lambda_gdl'] * gdl_loss \
                  + self.hyperparams['lambda_adv'] * adv_loss
        # Only the generator's own variables are updated by this optimiser.
        self.opt = tf.train.AdamOptimizer(learning_rate=self.hyperparams['learning_rate']).minimize(self.loss, var_list=self.train_variables)
        with tf.name_scope('generator'):
            l2_loss_summ = tf.summary.scalar('l2_loss', l2_loss)
            gdl_loss_summ = tf.summary.scalar('gradient_difference_loss', gdl_loss)
            adversarial_loss_summ = tf.summary.scalar('adversarial_loss', adv_loss)
            loss_summ = tf.summary.scalar('loss', self.loss)
            self.summaries = tf.summary.merge([l2_loss_summ, gdl_loss_summ, adversarial_loss_summ, loss_summ])

    def get_output_tensor(self, input):
        """Apply the generator to ``input`` and return the output tensor.

        Three convolutions (the last two with stride 2) encode the input, one convolution forms
        the middle stage, and two conv+upsample stages decode it, each fed the concatenation of
        the previous stage with the encoder feature map of matching resolution. The first call
        creates the variables and records them in ``self.train_variables``; later calls reuse them.
        """
        with tf.variable_scope('generator', reuse=self.has_defined_layers):
            encoder0 = conv2d(input, 4, 'encoder_0')
            encoder1 = conv2d(encoder0, 8, 'encoder_1', stride=2)
            encoder2 = conv2d(encoder1, 16, 'encoder_2', stride=2)
            mid0 = conv2d(encoder2, 16, 'mid_0')
            decoder0 = deconv2d(tf.concat([mid0, encoder2], axis=3), 8, 'decoder_0')
            decoder1 = deconv2d(tf.concat([decoder0, encoder1], axis=3), 4, 'decoder_1')
            if self.input_residual:
                # Residual mode: the last layer also sees encoder0 and its output is added to the input.
                output = input + conv2d(tf.concat([decoder1, encoder0], axis=3), 3, 'decoder_2')
            else:
                output = conv2d(decoder1, 3, 'decoder_2')
        if not self.has_defined_layers:
            self.train_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
            self.has_defined_layers = True
        return output

    def fit_batch(self, inputs, expected_outputs, alpha):
        """Run one optimiser step on a batch and return ``(loss, serialized_summaries)``.

        ``alpha`` is always fed, even when ``alpha_mask_loss`` is False.
        """
        feed = {
            self.train_input : inputs,
            self.expected_output : expected_outputs,
            self.alpha_mask : alpha
        }
        _, loss, summaries = self.sess.run((self.opt, self.loss, self.summaries), feed_dict=feed)
        return loss, summaries

    def predict(self, inputs):
        """Return the generator output for a batch of test-sized ``inputs``."""
        return self.sess.run(self.test_output, feed_dict={self.test_input : inputs})
    
class DiscriminatorNetwork(object):
    """Convolutional classifier estimating the probability that an image is a real de-lighted image.

    ``get_output_tensor`` may be called before ``init_network`` (the generator does so for its
    adversarial loss); the first call creates the variables under the ``discriminator`` variable
    scope and every later call reuses them.
    """

    # Added inside tf.log so a saturated sigmoid (exactly 0 or 1) cannot produce log(0).
    LOG_EPSILON = 1e-8

    def __init__(self, session, train_input_dimensions, learning_rate=1e-3):
        """Store the configuration; no graph nodes are created here.

        Args:
            session: session used by ``fit_batch``.
            train_input_dimensions: per-sample shape of the images being classified.
            learning_rate: Adam learning rate.
        """
        self.sess = session
        self.train_input_dimensions = train_input_dimensions
        self.learning_rate = learning_rate
        self.train_variables = []
        self.has_defined_layers = False

    def init_network(self, generator):
        """Create placeholders, the real/fake losses and the optimiser op.

        Args:
            generator: object whose ``get_output_tensor(tensor)`` maps lighted images to
                generated images; its output is used as the "fake" sample.
        """
        self.lighted_input = tf.placeholder(tf.float32, shape=[None,] + list(self.train_input_dimensions))
        self.delighted_input = tf.placeholder(tf.float32, shape=[None,] + list(self.train_input_dimensions))
        real_input = self.delighted_input
        fake_input = generator.get_output_tensor(self.lighted_input)

        predicted_real = self.get_output_tensor(real_input)
        predicted_fake = self.get_output_tensor(fake_input)

        # Cross-entropy terms; the epsilon keeps both logs finite when the sigmoid saturates.
        real_loss = -tf.reduce_mean(tf.log(predicted_real + self.LOG_EPSILON))
        fake_loss = -tf.reduce_mean(tf.log(1 - predicted_fake + self.LOG_EPSILON))

        self.loss = real_loss + fake_loss
        # var_list restricts the update to discriminator variables, leaving the generator untouched.
        self.opt = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.loss, var_list=self.train_variables)
        with tf.name_scope('discriminator'):
            real_loss_summ = tf.summary.scalar('real_loss', real_loss)
            fake_loss_summ = tf.summary.scalar('fake_loss', fake_loss)
            loss_summ = tf.summary.scalar('loss', self.loss)
            self.summaries = tf.summary.merge([real_loss_summ, fake_loss_summ, loss_summ])

    def get_output_tensor(self, input):
        """Return the probability that ``input`` is real.

        Three stride-2 convolutions are followed by fully connected layers of 256 and 128 units
        and a single sigmoid output unit. The first call creates the variables and records them
        in ``self.train_variables``; later calls reuse them.
        """
        with tf.variable_scope('discriminator', reuse=self.has_defined_layers):
            conv0 = conv2d(input, 4, 'conv_0', stride=2)
            conv1 = conv2d(conv0, 8, 'conv_1', stride=2)
            conv2 = conv2d(conv1, 16, 'conv_2', stride=2)
            fc_input = tf.contrib.layers.flatten(conv2)
            fc_0 = fc(fc_input, 256, 'fc_0')
            fc_1 = fc(fc_0, 128, 'fc_1')
            output = fc(fc_1, 1, 'fc_2', activation=tf.nn.sigmoid)
        if not self.has_defined_layers:
            self.train_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
            self.has_defined_layers = True
        return output

    def fit_batch(self, lighted_inputs, delighted_inputs):
        """Run one optimiser step on a batch and return ``(loss, serialized_summaries)``."""
        feed = {
            self.lighted_input : lighted_inputs,
            self.delighted_input : delighted_inputs
        }
        _, loss, summaries = self.sess.run((self.opt, self.loss, self.summaries), feed_dict=feed)
        return loss, summaries

In [6]:
def get_train_batch(delighted_data, lighted_data, batch_size, i, full_height=None, train_height=None):
    """Build one training batch of randomly cropped, rotated and flipped image pairs.

    For every ``(mesh, image)`` entry in slice ``i`` of ``lighted_data`` a random square crop is
    drawn until more than 10% of ``train_height ** 2`` of its values are positive. The same crop
    window is taken from ``delighted_data[mesh]`` and the *same* random rotation (multiple of 90
    degrees) and optional flip along axis 0 are applied to both, so input and target stay
    pixel-aligned.

    Args:
        delighted_data: mapping from mesh label to the target image array.
        lighted_data: sequence of ``(mesh_label, image_array)`` pairs.
        batch_size: number of entries per batch.
        i: zero-based batch index into ``lighted_data``.
        full_height: side length of the source images; defaults to ``FULL_HEIGHT``.
        train_height: side length of the crops; defaults to ``TRAIN_HEIGHT``.

    Returns:
        Tuple ``(expected_outputs, inputs, alphas)`` of arrays; ``alphas`` has one channel that is
        1 where the target crop has any non-zero channel and 0 elsewhere.

    Note:
        Sampling loops until a crop passes the threshold, so an image with no sufficiently
        non-zero region never terminates.
    """
    if full_height is None:
        full_height = FULL_HEIGHT
    if train_height is None:
        train_height = TRAIN_HEIGHT

    inputs = []
    expected_outputs = []
    alphas = []
    min_positive = 0.1 * train_height ** 2
    # The loop variable must not reuse the name `i`, which is the batch index parameter.
    for mesh, image in lighted_data[i * batch_size : (i + 1) * batch_size]:
        while True:
            start_x, start_y = np.random.randint(0, full_height - train_height + 1, size=(2,))
            sample = image[start_x : start_x + train_height, start_y : start_y + train_height]
            if np.sum(sample > 0) <= min_positive:
                # Mostly empty crop: draw another window.
                continue
            delighted = delighted_data[mesh][start_x : start_x + train_height, start_y : start_y + train_height]
            rotate_rand = np.random.randint(4)
            # The upper bound of randint is exclusive, so it must be 2 to ever yield 1.
            flip_rand = np.random.randint(2)
            # Apply identical geometric transforms to the input and its target.
            sample = np.rot90(sample, rotate_rand)
            delighted = np.rot90(delighted, rotate_rand)
            if flip_rand:
                sample = np.flip(sample, axis=0)
                delighted = np.flip(delighted, axis=0)
            inputs.append(sample)
            expected_outputs.append(delighted)
            alphas.append(np.any(delighted, axis=2, keepdims=True).astype(int))
            break
    return np.asarray(expected_outputs), np.asarray(inputs), np.asarray(alphas)

def get_test_batch(delighted_data, lighted_data, batch_size, i):
    """Return ``(expected_outputs, inputs)`` arrays for batch ``i`` without cropping or augmentation.

    ``lighted_data`` is a sequence of ``(mesh_label, image)`` pairs; each expected output is
    looked up in ``delighted_data`` by the pair's mesh label. An empty slice raises ValueError
    from the tuple unpacking.
    """
    first = i * batch_size
    last = (i + 1) * batch_size
    meshes, images = zip(*lighted_data[first:last])
    targets = [delighted_data[mesh] for mesh in meshes]
    return np.asarray(targets), np.asarray(images)

def restore(sess, checkpoint_file):
    """Load variable values from ``checkpoint_file`` into ``sess`` via a newly built ``tf.train.Saver``."""
    tf.train.Saver().restore(sess, checkpoint_file)
    
def save(sess, checkpoint_file):
    """Write the variables of ``sess`` to ``checkpoint_file`` via a newly built ``tf.train.Saver``."""
    tf.train.Saver().save(sess, checkpoint_file)

In [7]:
# Start from an empty default graph so re-running this cell does not collide with variables
# created by a previous run.
tf.reset_default_graph()
sess = tf.Session()
# Loss weights for this run: the gradient-difference term is disabled (weight 0).
gen_hyperparams = {
    'learning_rate' : 1e-4,
    'gdl_alpha' : 1,
    'lambda_l2' : 0.3,
    'lambda_gdl' : 0,
    'lambda_adv' : 5
}
generator = GeneratorNetwork(sess, train_input_dimensions, test_input_dimensions, gen_hyperparams,
                             input_residual=False,
                             alpha_mask_loss=True)
discriminator = DiscriminatorNetwork(sess, train_input_dimensions, learning_rate=1e-4)
# generator.init_network calls discriminator.get_output_tensor, which creates the discriminator
# variables; discriminator.init_network then reuses them and reuses the generator's layers.
generator.init_network(discriminator)
discriminator.init_network(generator)

# Must run after both init_network calls so every variable built above gets initialised.
sess.run(tf.global_variables_initializer())

In [None]:
# Training configuration.
epochs = 200
n_samples = len(train_lighted_data)
batch_size = 50
display_step = 1
# Samples beyond the last full batch are never visited.
# NOTE(review): if n_samples < batch_size this is 0 and the per-epoch mean below divides by zero.
iters_per_epoch = n_samples // batch_size
# Summaries are written once per epoch, on its last iteration.
summary_interval = iters_per_epoch
start_epoch = 0

summary_writer = tf.summary.FileWriter('summaries/test5', graph=sess.graph)
mean_gen_losses = []
mean_disc_losses = []
for epoch in range(start_epoch, start_epoch + epochs):
    total_gen_loss = 0
    total_disc_loss = 0
    for i in range(iters_per_epoch):
        expected_outputs, inputs, alphas = get_train_batch(delighted_data, train_lighted_data, batch_size, i)
        # One generator step followed by one discriminator step on the same batch.
        gen_loss, gen_summaries = generator.fit_batch(inputs, expected_outputs, alphas)
        disc_loss, disc_summaries = discriminator.fit_batch(inputs, expected_outputs)
        if (i + 1) % summary_interval == 0:
            # The summary step is expressed in samples rather than iterations.
            step = epoch * n_samples + (i + 1) * batch_size
            summary_writer.add_summary(gen_summaries, step)
            summary_writer.add_summary(disc_summaries, step)
        total_gen_loss += gen_loss
        total_disc_loss += disc_loss
    mean_gen_loss = total_gen_loss / iters_per_epoch
    mean_disc_loss = total_disc_loss / iters_per_epoch
    mean_gen_losses.append(mean_gen_loss)
    mean_disc_losses.append(mean_disc_loss)
    if (epoch + 1) % display_step == 0:
        print('epoch %s: gen_loss=%.4f, disc_loss=%.4f' % (epoch + 1, mean_gen_loss, mean_disc_loss))

In [None]:
# Pick one held-out sample (index 2 of the test split) to visualise.
label, img = test_lighted_data[2]
def show(img):
    """Render ``img`` with matplotlib, with the axes hidden."""
    plt.imshow(img)
    plt.axis('off')
    plt.show()
print('lighted')
show(img)
print('delighted')
show(delighted_data[label])
print('predicted')
# predict takes a batch, so the image is wrapped in a length-1 array and unwrapped with [0].
# float_to_uint8 is not defined in this notebook -- presumably provided by `from util import *`.
show(float_to_uint8(generator.predict(np.array([img])))[0])