## IMPORTS

In [None]:
import skimage.io
import skimage.transform
import skimage.exposure
from PIL import ImageFile
import os
from glob import glob
import numpy as np
import pandas as pd
import ipdb
import tensorflow as tf
import pickle as cPickle
tf.compat.v1.disable_eager_execution()
import cv2
import matplotlib.pyplot as plt
#from tensorflow.python.keras.models 

1. Load Image 

In [None]:

def load_image(path, ori_height=146, ori_width=146, height=128, width=128):
    """Load an image file and return a random RGB crop scaled to [-1, 1].

    The image is centre-cropped to a square along its shorter side, resized to
    ``ori_height x ori_width`` and then a random ``height x width`` window is
    cut out of it (light translation augmentation). Pixel values are divided
    by 255, i.e. 8-bit input is assumed.

    Args:
        path: file path handed to ``skimage.io.imread``.
        ori_height, ori_width: size after the square crop + resize.
        height, width: size of the returned random window; must not exceed
            ``ori_height`` / ``ori_width``.

    Returns:
        float ndarray of shape ``(height, width, 3)`` with values in [-1, 1],
        or ``None`` when the file cannot be read or has an unsupported layout
        (more than 3 dimensions or more than 4 channels).
    """
    try:
        image = skimage.io.imread(path).astype(float)
    except Exception:  # unreadable / corrupt / missing file: signal with None
        return None

    if image is None or image.ndim < 2 or image.ndim > 3:
        return None

    image = image / 255.

    if image.ndim == 2:
        image = image[:, :, None]
    channels = image.shape[2]
    if channels == 2:      # grayscale + alpha: keep the intensity channel
        image = image[:, :, :1]
    elif channels == 4:    # RGBA: drop alpha
        image = image[:, :, :3]
    elif channels > 4:
        return None
    if image.shape[2] == 1:  # grayscale -> 3 identical channels
        image = np.tile(image, (1, 1, 3))

    # Centre square crop along the shorter side, then resize.
    shortest_e = min(image.shape[:2])
    y = (image.shape[0] - shortest_e) // 2
    x = (image.shape[1] - shortest_e) // 2
    crop_img = image[y:y + shortest_e, x:x + shortest_e]
    resized = skimage.transform.resize(crop_img, [ori_height, ori_width])

    # randint's upper bound is exclusive: "+ 1" makes the largest valid offset
    # reachable and allows ori_* == target size (offset 0).
    rand_y = np.random.randint(0, ori_height - height + 1)
    rand_x = np.random.randint(0, ori_width - width + 1)
    resized = resized[rand_y:rand_y + height, rand_x:rand_x + width, :]

    # [0, 1] -> [-1, 1]
    return resized * 2 - 1


In [None]:
img = load_image('dl_project-1/svt1/train/01_16.jpg')
if img is None:  # load_image signals unreadable files with None
    raise FileNotFoundError("could not read dl_project-1/svt1/train/01_16.jpg")
# load_image returns values in [-1, 1]; imshow expects float RGB in [0, 1].
img = (img + 1) / 2
plt.imshow(img)

2. Crop-and-Mask Function (cut out a patch and blank its centre)

In [None]:
def crop_central(image_ori, width=64, height=64, x=None, y=None, overlap=7):
    """Cut a ``height x width`` patch out of an image and mask it in a copy.

    Args:
        image_ori: (H, W, 3) image, expected in [-1, 1] (the fill colour
            below is expressed in that range). Not modified.
        width, height: size of the patch.
        x: column of the patch's top-left corner. Drawn uniformly from
            ``[overlap, width - overlap)`` when omitted.
        y: row of the patch's top-left corner. Drawn uniformly from
            ``[overlap, height - overlap)`` when omitted.
        overlap: width of the patch border that stays visible in the
            masked image; only the inner region is filled.

    Returns:
        ``(masked_image, patch, x, y)``, where ``patch`` is the original
        content of the cut-out region, or ``None`` if ``image_ori`` is None.
    """
    if image_ori is None:
        return None
    # Row offset is drawn first, keeping the random-number order of the
    # default (fully random) call unchanged.
    random_y = np.random.randint(overlap, height - overlap) if y is None else y
    random_x = np.random.randint(overlap, width - overlap) if x is None else x

    crop = image_ori[random_y:random_y + height, random_x:random_x + width].copy()

    # Fill the inner region with RGB (117, 104, 123) mapped from [0, 255] to
    # [-1, 1] (presumably a mean pixel colour).
    image = image_ori.copy()
    image[random_y + overlap:random_y + height - overlap,
          random_x + overlap:random_x + width - overlap, :] = np.array(
        [2*117./255.-1, 2*104./255.-1, 2*123./255.-1])

    return image, crop, random_x, random_y

In [None]:
img = load_image('dl_project-1/svt1/train/01_16.jpg')
if img is None:  # load_image signals unreadable files with None
    raise FileNotFoundError("could not read dl_project-1/svt1/train/01_16.jpg")
# Mask while the image is still in [-1, 1]: crop_central's fill colour is
# expressed in that range. Masking after the [0, 1] rescale produced negative
# fill values that imshow clipped to black.
img, _, _, _ = crop_central(img, x=32, y=32)
img = (img + 1) / 2

plt.imshow(img)
plt.show()

## MODEL

In [None]:
class Model():
    """Context-encoder style inpainting GAN built with the TF1 graph API.

    ``reconstruction_loss`` builds the encoder/decoder generator under the
    variable scope ``GENERATOR3``; ``adversarial_loss`` builds the patch
    discriminator under ``DISCRIMINATOR3``. Both scopes use ``AUTO_REUSE``.

    Every weight is created with ``tf.compat.v1.get_variable`` inside a
    ``variable_scope``. ``tf.Variable`` ignores variable-scope reuse, so
    with it each call of a builder (e.g. the discriminator applied once to
    real and once to generated patches) silently received its own, unshared
    weights. Variable names (``<SCOPE>/<layer>/W`` and ``.../b``) are the same
    as those the first call produced before, so existing checkpoints load.
    """

    def __init__(self):
        pass

    def _weight(self, shape, name="W"):
        """Create (or reuse) a N(0, 0.005) weight in the current variable scope."""
        return tf.compat.v1.get_variable(
            name, [int(d) for d in shape],
            initializer=tf.compat.v1.random_normal_initializer(mean=0.0, stddev=0.005))

    def _bias(self, size, name="b"):
        """Create (or reuse) a zero-initialised bias in the current variable scope."""
        return tf.compat.v1.get_variable(
            name, [int(size)], initializer=tf.compat.v1.zeros_initializer())

    def conv_layer(self, bottom, filter_shape, activation=tf.identity, padding='SAME', stride=1, name=None):
        """Return ``activation(conv2d(bottom, W) + b)``.

        ``filter_shape`` is ``[kh, kw, in_channels, out_channels]``; the
        stride is applied to both spatial axes of an NHWC ``bottom``.
        """
        with tf.compat.v1.variable_scope(name, default_name='conv'):
            w = self._weight(filter_shape)
            b = self._bias(filter_shape[-1])
            conv = tf.nn.conv2d(bottom, w, strides=[1, stride, stride, 1], padding=padding)
            return activation(tf.nn.bias_add(conv, b))

    def deconv_layer(self, bottom, filter_shape, output_shape, activation=tf.identity, padding='SAME', stride=1, name=None):
        """Return ``activation(conv2d_transpose(bottom, W) + b)``.

        ``filter_shape`` follows ``tf.nn.conv2d_transpose``:
        ``[kh, kw, out_channels, in_channels]``, hence the bias has
        ``filter_shape[-2]`` entries. ``output_shape`` is the full NHWC shape.
        """
        with tf.compat.v1.variable_scope(name, default_name='deconv'):
            W = self._weight(filter_shape)
            b = self._bias(filter_shape[-2])
            deconv = tf.nn.conv2d_transpose(bottom, W, output_shape, strides=[1, stride, stride, 1], padding=padding)
            return activation(tf.nn.bias_add(deconv, b))

    def fully_conn_layer(self, bottom, output_size, name):
        """Flatten everything but the batch axis and apply a dense layer (no activation)."""
        shape = bottom.get_shape().as_list()
        dim = int(np.prod(shape[1:]))
        x = tf.reshape(bottom, [-1, dim])

        with tf.compat.v1.variable_scope(name, default_name='fc'):
            w = self._weight([dim, output_size])
            b = self._bias(output_size)
            return tf.compat.v1.nn.bias_add(tf.matmul(x, w), b)

    def channel_wise_layer(self, input, name):
        """Per-channel fully connected layer over the spatial positions.

        ``input`` is NHWC. Each of the C channels gets its own
        ``(H*W) x (H*W)`` matrix; the output keeps the input's shape.
        """
        _, height, width, n_feat_map = input.get_shape().as_list()
        input_reshape = tf.reshape(input, [-1, height * width, n_feat_map])
        input_transpose = tf.transpose(input_reshape, [2, 0, 1])   # (C, N, H*W)

        with tf.compat.v1.variable_scope(name, default_name='channel_wise'):
            W = self._weight([n_feat_map, height * width, height * width])
            output = tf.matmul(input_transpose, W)

        output_transpose = tf.transpose(output, [1, 2, 0])          # (N, H*W, C)
        # Restore (N, H, W, C) in the same axis order as the input.
        return tf.reshape(output_transpose, [-1, height, width, n_feat_map])

    def leaky_relu(self, bottom, leak=0.1):
        """Leaky ReLU: ``max(leak * x, x)`` (valid for 0 <= leak <= 1)."""
        return tf.maximum(leak * bottom, bottom)

    def batchnorm(self, bottom, is_train, epsilon=1e-8, name=None):
        """Batch normalisation over axes [0, 1, 2] (per channel of NHWC input).

        When the boolean tensor ``is_train`` is True the batch moments are used
        and an exponential moving average (decay 0.5) of them is updated;
        otherwise the moving averages are used.
        """
        # Clip extreme activations before computing moments (presumably to
        # keep them finite).
        bottom = tf.clip_by_value(bottom, -100., 100.)
        depth = bottom.get_shape().as_list()[-1]

        with tf.compat.v1.variable_scope(name):
            gamma = tf.compat.v1.get_variable("gamma", [depth], initializer=tf.constant_initializer(1.))
            beta = tf.compat.v1.get_variable("beta", [depth], initializer=tf.constant_initializer(0.))

            batch_mean, batch_var = tf.nn.moments(bottom, [0, 1, 2], name='moments')
            exp_ma = tf.compat.v1.train.ExponentialMovingAverage(decay=0.5)
            ema_op = exp_ma.apply([batch_mean, batch_var])
            ema_mean, ema_var = exp_ma.average(batch_mean), exp_ma.average(batch_var)

            def update():
                # Hand out the batch moments only after the averages were updated.
                with tf.control_dependencies([ema_op]):
                    return tf.identity(batch_mean), tf.identity(batch_var)

            mean, var = tf.cond(is_train, update, lambda: (ema_mean, ema_var))

            # NOTE(review): scale_after_normalization=False means gamma is
            # created but never applied; kept for behaviour/checkpoint parity.
            normalized = tf.compat.v1.nn.batch_norm_with_global_normalization(
                bottom, mean, var, beta, gamma, epsilon, False)
        return normalized

    def reconstruction_loss(self, images, is_train):
        """Build the generator and return its activations and output.

        Assumes ``images`` is ``[batch, 128, 128, 3]``: five stride-2 SAME
        convolutions reduce it to 4x4, a 4x4 VALID convolution to a 1x1x4000
        bottleneck, and the decoder produces a ``[batch, 64, 64, 3]`` patch.

        Returns:
            The ten post-activation encoder/decoder maps (for debugging), the
            raw decoder output and its ``tanh`` (the patch in [-1, 1]).
        """
        batch_size = images.get_shape().as_list()[0]

        with tf.compat.v1.variable_scope('GENERATOR3', reuse=tf.compat.v1.AUTO_REUSE):
            conv_layer1 = self.conv_layer(images, [4,4,3,64], stride=2, name="c1")
            batch_norm1 = self.leaky_relu(self.batchnorm(conv_layer1, is_train, name='BN1'))
            conv_layer2 = self.conv_layer(batch_norm1, [4,4,64,64], stride=2, name="c2")
            batch_norm2 = self.leaky_relu(self.batchnorm(conv_layer2, is_train, name='BN2'))
            conv_layer3 = self.conv_layer(batch_norm2, [4,4,64,128], stride=2, name="c3")
            batch_norm3 = self.leaky_relu(self.batchnorm(conv_layer3, is_train, name='BN3'))
            conv_layer4 = self.conv_layer(batch_norm3, [4,4,128,256], stride=2, name="c4")
            batch_norm4 = self.leaky_relu(self.batchnorm(conv_layer4, is_train, name='BN4'))
            conv_layer5 = self.conv_layer(batch_norm4, [4,4,256,512], stride=2, name="c5")
            batch_norm5 = self.leaky_relu(self.batchnorm(conv_layer5, is_train, name='BN5'))
            conv_layer6 = self.conv_layer(batch_norm5, [4,4,512,4000], stride=2, padding='VALID', name='c6')
            batch_norm6 = self.leaky_relu(self.batchnorm(conv_layer6, is_train, name='BN6'))

            de_conv_layer4 = self.deconv_layer(batch_norm6, [4,4,512,4000], conv_layer5.get_shape().as_list(), padding='VALID', stride=2, name="d4")
            de_batch_norm4 = tf.nn.relu(self.batchnorm(de_conv_layer4, is_train, name='DBN4'))
            de_conv_layer3 = self.deconv_layer(de_batch_norm4, [4,4,256,512], conv_layer4.get_shape().as_list(), stride=2, name="d3")
            de_batch_norm3 = tf.nn.relu(self.batchnorm(de_conv_layer3, is_train, name='DBN3'))
            de_conv_layer2 = self.deconv_layer(de_batch_norm3, [4,4,128,256], conv_layer3.get_shape().as_list(), stride=2, name="d2")
            de_batch_norm2 = tf.nn.relu(self.batchnorm(de_conv_layer2, is_train, name='DBN2'))
            de_conv_layer1 = self.deconv_layer(de_batch_norm2, [4,4,64,128], conv_layer2.get_shape().as_list(), stride=2, name="d1")
            de_batch_norm1 = tf.nn.relu(self.batchnorm(de_conv_layer1, is_train, name='DBN1'))
            recon = self.deconv_layer(de_batch_norm1, [4,4,3,64], [batch_size,64,64,3], stride=2, name="recon")

        return batch_norm1, batch_norm2, batch_norm3, batch_norm4, batch_norm5, batch_norm6, de_batch_norm4, de_batch_norm3, de_batch_norm2, de_batch_norm1, recon, tf.nn.tanh(recon)

    def adversarial_loss(self, images, is_train, reuse=True):
        """Build (or reuse) the discriminator and return one logit per image.

        Every call shares the same ``DISCRIMINATOR3`` weights via
        ``AUTO_REUSE``; ``reuse`` is accepted only for backward compatibility
        and is not used.
        """
        with tf.compat.v1.variable_scope('DISCRIMINATOR3', reuse=tf.compat.v1.AUTO_REUSE):
            conv_layer1 = self.conv_layer(images, [4,4,3,64], stride=2, name="c1")
            batch_norm1 = self.leaky_relu(self.batchnorm(conv_layer1, is_train, name='BN1'))
            conv_layer2 = self.conv_layer(batch_norm1, [4,4,64,128], stride=2, name="c2")
            batch_norm2 = self.leaky_relu(self.batchnorm(conv_layer2, is_train, name='BN2'))
            conv_layer3 = self.conv_layer(batch_norm2, [4,4,128,256], stride=2, name="c3")
            batch_norm3 = self.leaky_relu(self.batchnorm(conv_layer3, is_train, name='BN3'))
            conv_layer4 = self.conv_layer(batch_norm3, [4,4,256,512], stride=2, name="c4")
            batch_norm4 = self.leaky_relu(self.batchnorm(conv_layer4, is_train, name='BN4'))

            output = self.fully_conn_layer(batch_norm4, output_size=1, name='output')

        return output[:, 0]

## TRAINING

In [None]:
# --- training schedule ---
n_epochs, batch_size = 50, 100
learning_rate_val = 3e-4    # fed to both Adam optimisers
weight_decay_rate = 1e-4    # scale of the L2 penalty on the "W" kernels
momentum = 0.9              # NOTE(review): not referenced by any visible cell

# --- generator loss weighting (reconstruction vs. adversarial term) ---
lambda_recon, lambda_adv = 0.999, 0.001

# --- patch geometry, in pixels ---
hiding_size = 64    # side of the square patch the generator fills in
overlap_size = 7    # border of that patch handled by the overlap loss

In [None]:
# Dataset root and the cached path indexes built from it.
dataset_path = 'dl_project-1/svt1/'
trainset_path = 'dl_project-1/svt1/train.pickle'
testset_path = 'dl_project-1/svt1/test.pickle'

# Where checkpoints and sample reconstructions are written.
model_path = 'dl_project-1/models/'
result_path = 'dl_project-1/results/'

# Checkpoint prefix to resume from; None starts from freshly initialised weights.
pretrained_model_path = None

In [None]:
# Output folders for checkpoints and sample reconstructions.
os.makedirs(model_path, exist_ok=True)
os.makedirs(result_path, exist_ok=True)

if not os.path.exists(trainset_path) or not os.path.exists(testset_path):
    # Build the image-path index once and cache it as pickles.
    trainset_dir = os.path.join(dataset_path, 'train')
    testset_dir = os.path.join(dataset_path, 'test')

    # Keep only regular, non-hidden files: entries such as ".DS_Store" or an
    # ".ipynb_checkpoints" folder would make load_image return None and break
    # the batch. Sorted so the cached order does not depend on the filesystem.
    trainset = pd.DataFrame({'image_path': [
        os.path.join(trainset_dir, name)
        for name in sorted(os.listdir(trainset_dir))
        if not name.startswith('.') and os.path.isfile(os.path.join(trainset_dir, name))
    ]})
    testset = pd.DataFrame({'image_path': [
        os.path.join(testset_dir, name)
        for name in sorted(os.listdir(testset_dir))
        if not name.startswith('.') and os.path.isfile(os.path.join(testset_dir, name))
    ]})

    trainset.to_pickle(trainset_path)
    testset.to_pickle(testset_path)
else:
    trainset = pd.read_pickle(trainset_path)
    testset = pd.read_pickle(testset_path)


In [None]:
print(len(testset))
# Reset to a 0..N-1 index, then shuffle the test set once (the evaluation in
# the training loop always uses its first batch_size rows).
testset.index = range(len(testset))
testset = testset.iloc[np.random.permutation(len(testset))]
# True for training steps, False for evaluation; Model.batchnorm uses it to
# choose batch moments vs. their moving averages.
is_train = tf.compat.v1.placeholder(tf.bool )

# Scalar learning rate, fed on every optimiser step.
learning_rate = tf.compat.v1.placeholder(tf.float32, [])
# Masked 128x128 RGB input images for the generator (fixed batch dimension).
images_tf = tf.compat.v1.placeholder(tf.float32, [batch_size, 128, 128, 3], name="images")
print(images_tf)
# Discriminator targets: ones for the real patches followed by zeros for the
# generated ones - must match the order in which the logits are concatenated.
labels_D = tf.concat([tf.ones([batch_size]), tf.zeros([batch_size])], axis=0)
# Generator target: its patches should be classified as real.
labels_G = tf.ones([batch_size])
print(labels_D,labels_G)
# Ground-truth hidden patches (hiding_size x hiding_size RGB).
images_hiding = tf.compat.v1.placeholder( tf.float32, [batch_size, hiding_size, hiding_size, 3], name='images_hiding')

In [None]:
model = Model()

# Generator: intermediate activations (fetched for debugging), the raw
# decoder output and its tanh (the generated patch in [-1, 1]).
batch_norm1, batch_norm2, batch_norm3, batch_norm4, batch_norm5, batch_norm6, de_batch_norm4, de_batch_norm3, de_batch_norm2, de_batch_norm1, reconstruction_ori, reconstruction = model.reconstruction_loss(images_tf, is_train)
# Discriminator logits (one per image) for real and for generated patches.
adversarial_pos = model.adversarial_loss(images_hiding, is_train)
adversarial_neg = model.adversarial_loss(reconstruction, is_train, reuse=True)

# Real logits first, then generated ones: the same order as labels_D
# (ones, then zeros). Both inputs are rank 1, so this is already flat.
adversarial_all = tf.concat([adversarial_pos, adversarial_neg], axis=0)
print(adversarial_all.shape)



In [None]:

# 1 inside the central (hiding_size - 2*overlap_size)^2 square, 0 on the
# overlap_size-wide border around it.
masking_reconst = tf.pad(
    tf.ones([hiding_size - 2 * overlap_size] * 2),
    [[overlap_size, overlap_size]] * 2,
)

masking_reconst = masking_reconst[:, :, tf.newaxis]
print(masking_reconst.shape)
# Same mask for all three colour channels; the overlap mask is its complement.
masking_reconst = tf.tile(masking_reconst, [1, 1, 3])
masking_overlap = 1 - masking_reconst

In [None]:
# Per-pixel squared error between the true hidden patch and the generated one.
recon_loss_ori = tf.square(images_hiding - reconstruction)
# Per-sample L2 norm over the central region and over the overlap border
# (the 1e-5 keeps sqrt's gradient finite at 0); the border is weighted 10x.
recon_loss_central = tf.reduce_mean(tf.sqrt(1e-5 + tf.reduce_sum(recon_loss_ori * masking_reconst, [1, 2, 3])))
recon_loss_overlap = tf.reduce_mean(tf.sqrt(1e-5 + tf.reduce_sum(recon_loss_ori * masking_overlap, [1, 2, 3]))) * 10.
recon_loss = recon_loss_central + recon_loss_overlap

# Flatten the logits to rank 1 so they match labels_D / labels_G.
adversarial_all = tf.reshape(adversarial_all, [-1])
adversarial_neg = tf.reshape(adversarial_neg, [-1])
# Keyword arguments are required: the TF2 positional order is (labels, logits),
# so the previous positional call treated the logits as labels and vice versa.
loss_adv_D = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_D, logits=adversarial_all))
loss_adv_G = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_G, logits=adversarial_neg))

loss_G = loss_adv_G * lambda_adv + recon_loss * lambda_recon
loss_D = loss_adv_D

In [None]:

# Split the trainable variables by the scope prefix of each network.
var_G = [v for v in tf.compat.v1.trainable_variables() if v.name.startswith('GENERATOR')]
var_D = [v for v in tf.compat.v1.trainable_variables() if v.name.startswith('DISCRIMINATOR')]

# Weight decay only touches the kernels named "W"; biases ("b") and the
# batch-norm "gamma"/"beta" are excluded.
W_G = [v for v in var_G if v.name.endswith('W:0')]
W_D = [v for v in var_D if v.name.endswith('W:0')]

loss_G += weight_decay_rate * tf.reduce_mean(tf.stack([tf.nn.l2_loss(w) for w in W_G]))
loss_D += weight_decay_rate * tf.reduce_mean(tf.stack([tf.nn.l2_loss(w) for w in W_D]))

# Cap this process at 40% of the GPU memory.
gpu = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=0.4)
sess = tf.compat.v1.InteractiveSession(config=tf.compat.v1.ConfigProto(gpu_options=gpu))

# One Adam optimiser (beta1 = 0.5) per network, each limited to its own variables.
optimizer_G = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate, beta1=0.5)
grads_vars_G = optimizer_G.compute_gradients(loss_G, var_list=var_G)
train_op_G = optimizer_G.apply_gradients(grads_vars_G)

optimizer_D = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate, beta1=0.5)
grads_vars_D = optimizer_D.compute_gradients(loss_D, var_list=var_D)
train_op_D = optimizer_D.apply_gradients(grads_vars_D)

# Created after the optimisers, so their slot variables are checkpointed too.
save_model = tf.compat.v1.train.Saver(max_to_keep=100)
tf.compat.v1.global_variables_initializer().run()

In [None]:
# Show the GPUs TensorFlow can see (the notebook displays the returned list).
tf.config.list_physical_devices(device_type='GPU')

In [None]:
# Optionally resume from a checkpoint prefix. A V2-format Saver checkpoint is
# stored as "<prefix>.index" + "<prefix>.data-*", with no file named "<prefix>",
# so the bare os.path.exists(prefix) test alone never matched such checkpoints.
if pretrained_model_path is not None and (
        os.path.exists(pretrained_model_path + '.index')
        or os.path.exists(pretrained_model_path)):
    save_model.restore(sess, pretrained_model_path)

i = 0  # global iteration counter, not reset between epochs

loss_D_val = 0.
loss_G_val = 0.
loss_G_ad_rec = []  # generator loss after every generator step
loss_D_adv = []     # latest discriminator loss, recorded at every generator step
print(len(trainset))
for epoch in range(n_epochs):
    # Reshuffle the training set at the start of every epoch.
    trainset.index = range(len(trainset))
    trainset = trainset.iloc[np.random.permutation(len(trainset))]

    # The placeholders have a fixed batch dimension, so only full batches are
    # used; "+ 1" makes the last full batch reachable (it used to be dropped).
    for start, end in zip(
            range(0, len(trainset), batch_size),
            range(batch_size, len(trainset) + 1, batch_size)):

        image_paths = trainset[start:end]['image_path'].values
        images_ori = [load_image(x) for x in image_paths]
        # load_image returns None for unreadable files; such a batch cannot be fed.
        if any(image is None for image in images_ori):
            print("Skipping batch starting at", start, "- unreadable image(s)")
            continue

        images, crops, _, _ = zip(*[crop_central(x) for x in images_ori])

        # Every 5 iterations: evaluate on the first test batch (moving-average
        # batch-norm statistics), save reconstructions and print value ranges.
        if i % 5 == 0:
            test_image_paths = testset[:batch_size]['image_path'].values
            test_images_ori = [load_image(x) for x in test_image_paths]

            test_images, test_crops, xs, ys = zip(
                *[crop_central(x, x=32, y=32) for x in test_images_ori])

            reconstruction_vals, recon_ori_vals, BN1_val,BN2_val,BN3_val,BN4_val,\
            BN5_val,BN6_val,deBN4_val, deBN3_val, deBN2_val, deBN1_val, loss_G_val, loss_D_val = sess.run(
                    [reconstruction, reconstruction_ori, batch_norm1,batch_norm2,batch_norm3,batch_norm4,batch_norm5,batch_norm6,de_batch_norm4, de_batch_norm3, de_batch_norm2, de_batch_norm1, loss_G, loss_D],
                    feed_dict={
                        images_tf: test_images,
                        images_hiding: test_crops,
                        is_train: False
                        })

            # Save the first 51 reconstructions pasted back into their masked images.
            for img_no, (rec_val, img, x, y) in enumerate(zip(reconstruction_vals, test_images, xs, ys)):
                if img_no > 50:
                    break
                # [-1, 1] -> [0, 255] as uint8 for image encoding.
                rec_hid = np.clip(255. * (rec_val + 1) / 2., 0, 255).astype(np.uint8)
                rec_con = np.clip(255. * (img + 1) / 2., 0, 255).astype(np.uint8)
                # Put the generated patch where crop_central cut the original out.
                rec_con[y:y + hiding_size, x:x + hiding_size] = rec_hid
                # The images are RGB (skimage); cv2.imwrite expects BGR.
                cv2.imwrite(
                    os.path.join(result_path, 'img_'+str(img_no)+'.'+str(int(i/1000))+'.jpg'),
                    cv2.cvtColor(rec_con, cv2.COLOR_RGB2BGR))

            print("Loss parameter Values at each operation: ")

            print(BN1_val.max(), BN1_val.min())
            print(BN2_val.max(), BN2_val.min())
            print(BN3_val.max(), BN3_val.min())
            print(BN4_val.max(), BN4_val.min())
            print(BN5_val.max(), BN5_val.min())
            print(BN6_val.max(), BN6_val.min())

            print(deBN4_val.max(), deBN4_val.min())
            print(deBN3_val.max(), deBN3_val.min())
            print(deBN2_val.max(), deBN2_val.min())
            print(deBN1_val.max(), deBN1_val.min())

            print(recon_ori_vals.max(), recon_ori_vals.min())
            print(reconstruction_vals.max(), reconstruction_vals.min())
            print(loss_G_val, loss_D_val)
            print('\n')

            if np.isnan(reconstruction_vals.min()) or np.isnan(reconstruction_vals.max()):
                print("NaN detected!!")
                ipdb.set_trace()

        # Generator step (every iteration).
        _, loss_G_val, adv_pos_val, adv_neg_val, loss_recon_val, loss_adv_G_val, \
            reconstruction_vals, recon_ori_vals, BN1_val,BN2_val,BN3_val,BN4_val,BN5_val,BN6_val, \
                deBN4_val, deBN3_val, deBN2_val, deBN1_val = sess.run(
                [train_op_G, loss_G, adversarial_pos, adversarial_neg,recon_loss, loss_adv_G, reconstruction, reconstruction_ori, \
                batch_norm1,batch_norm2,batch_norm3,batch_norm4,batch_norm5,batch_norm6,\
                de_batch_norm4, de_batch_norm3, de_batch_norm2, de_batch_norm1],
                feed_dict={
                    images_tf: images,
                    images_hiding: crops,
                    learning_rate: learning_rate_val,
                    is_train: True
                    })

        loss_G_ad_rec.append(loss_G_val)
        loss_D_adv.append(loss_D_val)
        # The discriminator is updated only once every 5 iterations.
        if i % 5 == 0:
            _, loss_D_val, adv_pos_val, adv_neg_val = sess.run(
                    [train_op_D, loss_D, adversarial_pos, adversarial_neg],
                    feed_dict={
                        images_tf: images,
                        images_hiding: crops,
                        learning_rate: learning_rate_val,
                        is_train: True
                        })

            print("Iteration:", i, "Generator Loss:", loss_G_val, "Reconconstuction Loss:", loss_recon_val, "Gen Adversarial Loss:", loss_adv_G_val,  "Discriminator Loss:", loss_D_val, "==", adv_pos_val.mean(), adv_neg_val.min(), adv_neg_val.max())

        i += 1

    # One checkpoint per epoch, then decay the learning rate.
    save_model.save(sess, model_path + 'model', global_step=epoch)
    learning_rate_val *= 0.99

plt.plot(range(len(loss_D_adv)), loss_D_adv, label='Adversarial Loss')
plt.plot(range(len(loss_G_ad_rec)), loss_G_ad_rec, label='Reconstruction + Adversarial Loss')
plt.legend()
plt.show()


## TESTING

In [None]:
# Same settings as for training (re-declared so this section can run on its own).
n_epochs, batch_size, momentum = 50, 100, 0.9
learning_rate_val, weight_decay_rate = 3e-4, 1e-4
lambda_recon, lambda_adv = 0.999, 0.001
hiding_size, overlap_size = 64, 7

In [None]:
# Saver.save() writes "<prefix>.index" / "<prefix>.data-*" plus a small text
# file literally named "checkpoint" that only records the newest prefix; that
# text file is not itself a readable checkpoint. Passing the directory lets
# TensorFlow resolve the newest checkpoint prefix.
ckpt_path = "dl_project-1/models/"

vars_list = tf.train.list_variables(ckpt_path)
for var in vars_list:
    print(var)

In [None]:
checkpoint_dir = "dl_project-1/models/"
# Text index file written by Saver.save(); it names the newest checkpoint.
checkpoint_path = "dl_project-1/models/checkpoint"
if os.path.exists(checkpoint_path):
    print("Checkpoint file exists.")
else:
    print("Checkpoint file does not exist.")

try:
    # The "checkpoint" file is only an index, so open the newest checkpoint
    # it points to via the directory.
    reader = tf.train.load_checkpoint(checkpoint_dir)
    var_to_shape_map = reader.get_variable_to_shape_map()
except Exception as e:  # notebook boundary: report and continue
    print("Error reading checkpoint file: ", str(e))
else:
    # Check the integrity of the variables by actually reading every stored
    # tensor. (init_from_checkpoint was used before: it expects a name->name
    # assignment map, not name->shape, and it rewrites graph initialisers.)
    try:
        for var_name in var_to_shape_map:
            reader.get_tensor(var_name)
    except Exception as e:
        print("Integrity check failed: ", str(e))
    else:
        print("Integrity check successful.")


In [None]:
testset_path = 'dl_project-1/svt1/test.pickle'
result_path = 'dl_project-1/results/results_test/'
# NOTE(review): with n_epochs = 50 and global_step=epoch the training loop
# writes model-0 ... model-49; confirm model-99 exists (e.g. from a longer run).
pretrained_model_path = 'dl_project-1/models/model-99'
testset = pd.read_pickle(testset_path)

# cv2.imwrite does not create missing folders, so make sure the output exists.
os.makedirs(result_path, exist_ok=True)

is_train = tf.compat.v1.placeholder(tf.bool)
images_tf = tf.compat.v1.placeholder(tf.float32, [batch_size, 128, 128, 3], name="images")


In [None]:
# Build the generator on the test placeholders and restore trained weights.
model = Model()

# NOTE(review): reconstruction_loss returns a 12-tuple (ten activations, the raw
# output and its tanh); this binds the whole tuple, so `reconstruction` is not a
# tensor here. The generated patch is the last element.
reconstruction = model.reconstruction_loss(images_tf, is_train)


# NOTE(review): called after the graph above was already built; presumably
# meant to run before any graph construction - confirm it is needed here.
tf.compat.v1.disable_v2_behavior()

sess = tf.compat.v1.InteractiveSession()

# Initialise every variable, then overwrite those tracked by the training-time
# Saver (`save_model`) with the checkpoint values.
# NOTE(review): `save_model` only covers variables that existed when it was
# created; variables newly created by the call above are restored only if they
# are shared with the training graph - verify.
tf.compat.v1.global_variables_initializer().run()
save_model.restore( sess, pretrained_model_path )

n = 0  # number of images written so far
# NOTE(review): `end` stops before len(testset), so the last full batch is skipped.
for start,end in zip(
        range(0, len(testset), batch_size),
        range(batch_size, len(testset), batch_size)):
    
    # NOTE(review): slices testset[:batch_size] on every iteration, ignoring
    # start/end, so the same first batch is processed each time.
    test_image_paths = testset[:batch_size]['image_path'].values
    test_images_ori = map(lambda x: load_image(x), test_image_paths)

    # Mask a 64x64 patch with its top-left corner at (32, 32) in every image.
    test_images_crop = map(lambda x: crop_central(x, x=32, y=32), test_images_ori)
    test_images, test_crops, xs,ys = zip(*test_images_crop)
    # Write the masked inputs, mapped from [-1, 1] to [0, 255], as *.ori.jpg.
    # NOTE(review): in the lines shown the model is never run; cv2.imwrite
    # interprets the array as BGR while these images are RGB.
    for img,x,y in zip(test_images, xs, ys):
        img_rgb = (255. * (img + 1)/2.).astype(int)
        cv2.imwrite(os.path.join(result_path, 'img_'+str(n)+'.ori.jpg'), img_rgb)
        n+=1
        # NOTE(review): only leaves the inner loop; the outer loop continues.
        if n>30: 
            break