In [1]:
import os
import tqdm
import re
import matplotlib.pyplot as plt
import numpy as np
import time
import warnings
from datetime import timedelta
import logging
logging.getLogger('tensorflow').disabled = True

import tensorflow as tf
from utils import *
from model import *
import argparse

tf.get_logger().setLevel('ERROR')


In [2]:
BATCH_SIZE = 32
NUM_CLASSES = 1
EPOCHS = 10
GENERATION_RATE = 2
LEARNING_RATE = 0.0002
truncated_uniform_scale_flag = True

IMG_WIDTH, IMG_HEIGHT = 128, 128
#BETA_1, BETA_2 = 0.5, 0.999
BETA_1, BETA_2 = 0.5, 0.999
thres_int = 0.5


In [3]:
model_name = None

In [4]:
""" saving paths """

if model_name is None:
    model_name = time.strftime('%Y-%m-%d_%H:%M:%S_%z') + "_" + str(BATCH_SIZE)
    output_dir = "output"
    model_dir = '{}/{}'.format(output_dir, model_name)
    image_dir = '{}/images'.format(model_dir)
    checkpoints_dir = '{}/checkpoints'.format(model_dir)
    for path in [output_dir, model_dir, image_dir, checkpoints_dir]:
        if not os.path.exists(path):
            os.mkdir(path)
    print("created model")
else:
    model_name = model_name
    print("proceeding to load model: {}".format(model_name))

created model


In [5]:
""" tf session definitions """
tf.reset_default_graph()
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 0.8
sess = tf.Session(config=config)

In [6]:
""" load TFRecordDataset """
training_data = Barkley_Deep_Drive('../../self-driving-AttGAN/resources/train.tfrecords')
validation_data = Barkley_Deep_Drive('../../self-driving-AttGAN/resources/test.tfrecords')

train_iterator = training_data.get_batch(EPOCHS, BATCH_SIZE, shuffle = False)
val_iterator = validation_data.get_batch(EPOCHS, BATCH_SIZE, shuffle = False)

train_image_iterator, train_label_iterator = train_iterator.get_next()
val_image_iterator, val_label_iterator = val_iterator.get_next()

In [7]:
""" Placeholders """
xa = tf.placeholder(tf.float32,shape=[BATCH_SIZE,IMG_WIDTH,IMG_HEIGHT,3],name="xa") #orignal image
z = encoder(xa, reuse=tf.compat.v1.AUTO_REUSE ) #encoder output

a = tf.placeholder(tf.float32, shape=[BATCH_SIZE, NUM_CLASSES],name="a") #original attributes
b = tf.placeholder(tf.float32, shape=[BATCH_SIZE, NUM_CLASSES],name="b") #desired attributes

xb_hat = decoder(z, b, reuse=tf.compat.v1.AUTO_REUSE ) #decoder output
with tf.control_dependencies([xb_hat]):
    xa_hat = decoder(z, a, reuse=tf.compat.v1.AUTO_REUSE ) #decoder output

xa_logit_D, xa_logit_C = classifier_and_discriminator(xa, reuse=tf.compat.v1.AUTO_REUSE,  NUM_CLASSES=1)
xb_logit_D, xb_logit_C = classifier_and_discriminator(xb_hat, reuse=tf.compat.v1.AUTO_REUSE,  NUM_CLASSES=1)


In [8]:
""" penalty """
lambda_ = {"3" : 1, "2" : 10, "1" : 100}

""" interpolated image noise"""
#epsilon = tf.random_uniform(shape=[BATCH_SIZE, 1, 1, 1], minval=0., maxval=1.)
#interpolated_image = xa + epsilon * (xb_hat - xa)
#c_interpolated, d_interpolated = classifier_and_discriminator(interpolated_image, reuse=tf.compat.v1.AUTO_REUSE, NUM_CLASSES=1)

# Gradient penalty
alpha = tf.random_uniform(
    shape=[IMG_WIDTH,1], 
    minval=0.,
    maxval=1.
)
differences =  xb_hat - xa
interpolates = xa + (alpha*differences)
gradients = tf.gradients(classifier_and_discriminator(interpolates, reuse=tf.AUTO_REUSE), [interpolates])[0]
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
gp = tf.reduce_mean((slopes-1.)**2)

""" D loss """
loss_adv_D =  - ( tf.reduce_mean(xa_logit_D) - tf.reduce_mean(xb_logit_D) )
#grad_d_interpolated = tf.gradients(d_interpolated, [interpolated_image])[0]
#slopes = tf.sqrt(1e-10 + tf.reduce_sum(tf.square(grad_d_interpolated), axis=[1, 2, 3]))
#gp = tf.reduce_mean((slopes - 1.) ** 2)

loss_cls_C = tf.losses.sigmoid_cross_entropy(a, xa_logit_C)

D_loss = loss_adv_D + gp * lambda_['2'] + loss_cls_C

""" G loss """
loss_adv_G = -tf.reduce_mean(xb_logit_D)
loss_cls_G = tf.losses.sigmoid_cross_entropy(b, xb_logit_C)
loss_rec = tf.losses.absolute_difference(xa, xa_hat)

G_loss =  loss_adv_G + lambda_['2'] * loss_cls_G + lambda_['1'] * loss_rec


In [9]:
""" Training """
# divide trainable variables into a group for D and a group for G
t_vars = tf.trainable_variables()
d_vars = [var for var in t_vars if 'C_D' in var.name ]
g_vars = [var for var in t_vars if 'G_' in var.name]
assert(len(t_vars) == len(d_vars ) + len(g_vars )), "mismatch in variable names"

d_optim = tf.train.AdamOptimizer(learning_rate = LEARNING_RATE,
                                 beta1 = BETA_1,
                                 beta2 = BETA_2).minimize(D_loss, var_list=d_vars)

g_optim = tf.train.AdamOptimizer(learning_rate = LEARNING_RATE,
                                 beta1 = BETA_1,
                                 beta2 = BETA_2).minimize(G_loss, var_list=g_vars)

""" Summary """
d_summary = summary({
    loss_adv_D: 'loss_adv_D',
    gp: 'gp',
    loss_cls_C: 'loss_cls_C',
}, scope='D_')

g_summary = summary({
    loss_adv_G: 'loss_adv_G',
    loss_cls_G: 'loss_cls_G',
    loss_rec: 'loss_rec',
}, scope='G_')

""" init """
summary_writer = tf.summary.FileWriter('./graphs', sess.graph)
saver = tf.train.Saver()
S_ = tf.Summary()

""" checkpoints load """
try:
    load_checkpoint(checkpoints_dir, sess, t_vars)
except:
    print("did not load checkpoints")
    sess.run(tf.global_variables_initializer())


did not load checkpoints


In [10]:
""" mapping defintions """
d_loss_epoch, g_loss_epoch = [], []
if NUM_CLASSES == 2:
    label_mapping = {'daytime': [1, 0], 'night': [0, 1]}
else:
    label_mapping = {'daytime': [1], 'night': [0]}
    flip = {'1':[0], '0': [1] }


In [None]:
""" main training loop """

for epoch_no in tqdm.tqdm(range(EPOCHS), total=EPOCHS):

    sess.run(val_iterator.initializer)
    sess.run(train_iterator.initializer)

    try:
        step = 0
        start_time = time.monotonic()
        d_loss_per_batch = []
        g_loss_per_batch = []
        while True:
            # Sample batch from dataset
            image_batch, label_batch = sess.run([val_image_iterator, val_label_iterator])
            #image_batch, label_batch = sess.run([train_image_iterator, train_label_iterator])

            # Transform label batch in our simple one hot encoded version
            a_label_batch = np.array([label_mapping[label.decode("utf-8")] for label in label_batch], dtype=np.float32)
            if truncated_uniform_scale_flag:
                b_label_batch = a_label_batch.copy().astype(np.float32)
                np.random.shuffle(b_label_batch)
                a_label_batch = (a_label_batch * 2 - 1) * thres_int
                b_label_batch = (b_label_batch * 2 - 1) * (np.random.uniform(size=b_label_batch.shape) + 2) / 4.0 * (2 * thres_int)
            else:
                b_label_batch = [flip[str(int(label))] for label in a_label_batch]
                a_label_batch = np.asarray(a_label_batch, dtype=np.float32)
                b_label_batch = np.asarray(b_label_batch, dtype=np.float32)

            # Optimize
            d_summary_opt, _, D_loss_val = sess.run([d_summary, d_optim, D_loss], feed_dict={xa:image_batch,
                                                                    a: a_label_batch, b: b_label_batch})
            if step % 5 == 0:
                g_summary_opt, _, G_loss_val = sess.run([g_summary, g_optim, G_loss], feed_dict={xa:image_batch,
                                                                    a: a_label_batch, b: b_label_batch})

            d_loss_per_batch.append(D_loss_val)
            g_loss_per_batch.append(G_loss_val)

            summary_writer.add_summary(d_summary_opt, epoch_no)
            summary_writer.add_summary(g_summary_opt, epoch_no)

            if step % 400 == 0:
                print("At step ", step, "we have")
                print("Gen loss: ",np.mean(g_loss_per_batch), " and Desc loss:", np.mean(d_loss_per_batch) , "\n ")
                S_.ParseFromString(d_summary_opt)
                print(S_)
                S_.ParseFromString(g_summary_opt)
                print(S_)

            step += 1
    except tf.errors.OutOfRangeError:
            checkpoint_save_path = saver.save(sess, '{}/Epoch_{}_{}.ckpt'.format(checkpoints_dir, str(epoch_no), str(step)))
            print('Model is saved at {}!'.format(checkpoint_save_path))

            # Generating reconstructed image xa_hat and flipped attribute image xb_hat
            #image_batch, label_batch = sess.run([val_image_iterator, val_label_iterator])
            image_batch, label_batch = sess.run([train_image_iterator, train_label_iterator])

            # Transform label batch in our simple one hot encoded version
            a_label_batch = np.array([label_mapping[label.decode("utf-8")] for label in label_batch])#, dtype=np.float64)
            if truncated_uniform_scale_flag:
                b_label_batch = tf.random_shuffle(a_label_batch)
                a_label_batch = (a_label_batch * 2 - 1) * thres_int
                b_label_batch = (tf.to_float(b_label_batch) * 2 - 1) * (tf.truncated_normal(tf.shape(b_label_batch)) + 2) / 4.0 * (2 * thres_int)
                b_label_batch = sess.run(b_label_batch)
            else:
                b_label_batch = [flip[str(int(label))] for label in a_label_batch]
                a_label_batch = np.asarray(a_label_batch, dtype=np.float32)
                b_label_batch = np.asarray(b_label_batch, dtype=np.float32)

            step_xb_hat = sess.run(xb_hat, feed_dict={a:a_label_batch, b:b_label_batch, xa:image_batch})
            step_xa_hat = sess.run(xa_hat, feed_dict={a:a_label_batch, b:b_label_batch, xa:image_batch})
            
            """image saving"""
            output_path = os.path.join(image_dir, "epoch_no_"+ str(epoch_no) +"_" +".png")
            plot_block_after_epoch(output_path, label_batch, image_batch, step_xa_hat, step_xb_hat,
                                   examples = 3, plot = True)
            

            end_time = time.monotonic()
            print("END OF EPOCH")
            fmt = "Epoch duration: {}".format(timedelta(seconds=end_time - start_time))
            print(fmt)
            d_loss_epoch.append(np.mean(d_loss_per_batch))
            g_loss_epoch.append(np.mean(g_loss_per_batch))
            print("Discriminator loss: ", d_loss_epoch[-1])
            print("Generator loss: ", g_loss_epoch[-1])
            print("-"*len(fmt))
            pass


checkpoint_save_path = saver.save(sess, '{}/Epoch_{}_{}.ckpt'.format(checkpoints_dir, str(epoch_no), str(step)))
print('Finished training\nModel has been saved at {}!'.format(checkpoint_save_path))
sess.close()

try:
    to_json = {"d_loss": d_loss_epoch,
           "g_loss": g_loss_epoch}
    with open(model_name+'.json', 'w') as f:
        json.dump(to_json, f)
except:
    print(d_loss_epoch)
    print(g_loss_epoch)

  0%|          | 0/10 [00:00<?, ?it/s]

At step  0 we have
Gen loss:  51.41764  and Desc loss: 7.2865767 
 
value {
  tag: "D_/loss_adv_D"
  simple_value: -0.01121532917022705
}
value {
  tag: "D_/gp"
  simple_value: 0.6553004384040833
}
value {
  tag: "D_/loss_cls_C"
  simple_value: 0.7447875142097473
}

value {
  tag: "G_/loss_adv_G"
  simple_value: 6.498315811157227
}
value {
  tag: "G_/loss_cls_G"
  simple_value: 0.7086101174354553
}
value {
  tag: "G_/loss_rec"
  simple_value: 0.378332257270813
}

At step  400 we have
Gen loss:  27.456  and Desc loss: 6.060248 
 
value {
  tag: "D_/loss_adv_D"
  simple_value: 0.486125111579895
}
value {
  tag: "D_/gp"
  simple_value: 0.9373887181282043
}
value {
  tag: "D_/loss_cls_C"
  simple_value: 0.8757617473602295
}

value {
  tag: "G_/loss_adv_G"
  simple_value: 0.8416420221328735
}
value {
  tag: "G_/loss_cls_G"
  simple_value: 0.7507617473602295
}
value {
  tag: "G_/loss_rec"
  simple_value: 0.07500743865966797
}

At step  800 we have
Gen loss:  24.174932  and Desc loss: 2.46799

 10%|█         | 1/10 [05:56<53:27, 356.36s/it]

END OF EPOCH
Epoch duration: 0:05:56.059150
Discriminator loss:  -56.918068
Generator loss:  53.121864
------------------------------
At step  0 we have
Gen loss:  176.04536  and Desc loss: -101.17202 
 
value {
  tag: "D_/loss_adv_D"
  simple_value: -109.17130279541016
}
value {
  tag: "D_/gp"
  simple_value: 0.7306140065193176
}
value {
  tag: "D_/loss_cls_C"
  simple_value: 0.6931471824645996
}

value {
  tag: "G_/loss_adv_G"
  simple_value: 153.3275146484375
}
value {
  tag: "G_/loss_cls_G"
  simple_value: 0.6931471824645996
}
value {
  tag: "G_/loss_rec"
  simple_value: 0.1578637808561325
}

At step  400 we have
Gen loss:  -82.000854  and Desc loss: -10.937937 
 
value {
  tag: "D_/loss_adv_D"
  simple_value: 3.6623382568359375
}
value {
  tag: "D_/gp"
  simple_value: 0.8960720896720886
}
value {
  tag: "D_/loss_cls_C"
  simple_value: 0.648134708404541
}

value {
  tag: "G_/loss_adv_G"
  simple_value: -124.4700927734375
}
value {
  tag: "G_/loss_cls_G"
  simple_value: 0.7900403738