In [2]:
import tensorflow as tf
from tensorflow import keras 
import numpy as np
physical_devices = tf.config.list_physical_devices('GPU') 
tf.config.experimental.set_memory_growth(physical_devices[0], True)

In [3]:
from  utils.dataset import Dataset

In [4]:
dataset = Dataset(image_size=(64,64),data_base_path="./data")

In [5]:
train_ds = dataset.get_train_ds()

In [19]:
import tensorflow as tf
from tensorflow.keras import backend as  K 
from tensorflow import keras
from tensorflow.keras import layers as L
import time
############################################################
# Conditioning Augmentation Network
############################################################
class ConditioningAugmentation(keras.Model):
    def __init__(self, *args, **kwargs):
        super(ConditioningAugmentation, self).__init__(*args, **kwargs)
        self.dense = L.Dense(256)
        self.activation = L.LeakyReLU(alpha=0.2)
        

    def call(self,input):
        x  = self.dense(input)
        phi= self.activation(x)
        mean = phi[:,:128]
        std =tf.math.exp(phi[:,128:])
        epsilon = K.random_normal(shape = K.constant((mean.shape[1], ), dtype = 'int32'))
        output = mean + std*epsilon
        return output,phi

class EmbeddingCompresssor(keras.Model):
    def __init__(self):
        super(EmbeddingCompresssor, self).__init__()
        self.dense = L.Dense(128)

    def call(self,input):
        x = self.dense(input)
        x = L.LeakyReLU(0.2)(x)
        return x


############################################################
# Stage 1 Generator Network (CGAN)
############################################################


def UpSamplingBlock(input,num_filters):
    x = L.UpSampling2D(size=2)(input)
    x = L.Conv2D(num_filters,kernel_size=3,padding='same',strides=1,use_bias=False)(x)
    x = L.BatchNormalization()(x)
    x = L.ReLU()(x)
    return x


class Stage1Generator(keras.Model):
    def __init__(self):
        super(Stage1Generator, self).__init__()
        self.augmentation = ConditioningAugmentation()
        self.concat = L.Concatenate(axis=1)
        self.dense = tf.keras.layers.Dense(units = 128*8*4*4, kernel_initializer = tf.random_normal_initializer(stddev = 0.02))
        self.reshape = tf.keras.layers.Reshape(target_shape = (4, 4, 128*8), input_shape = (128*8*4*4, ))
        self.batchnorm1 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
        self.activation = L.ReLU()

    def call(self,inputs):
        embedding , noise = inputs
        c , phi = self.augmentation(embedding)
        gen_input = self.concat([c,noise])
        
        x = self.dense(gen_input)
        x = self.reshape(x)
        x = self.batchnorm1(x)
        x = self.activation(x)
        x = UpSamplingBlock(x, 512)
        x = UpSamplingBlock(x, 256)
        x = UpSamplingBlock(x,128)
        x = UpSamplingBlock(x,3)
        x = L.Conv2D(3,kernel_size=3,padding='same')(x)
        x = L.Activation('tanh')(x)

        return x,phi

class Stage1Discriminator(keras.Model):
    def __init__(self,*args, **kwargs):
        super(Stage1Discriminator, self).__init__(*args,**kwargs)
        
        self.l1 = L.Conv2D(64,kernel_size=4,strides=2,padding='same',kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev = 0.02 ))
        self.l2 = L.Conv2D(128,kernel_size=4,strides=2,padding='same',kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev =0.02))
        self.l3 = L.BatchNormalization(axis = -1)
        self.l4 = L.Conv2D(256,kernel_size=4,strides=2,padding='same',kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev = 0.02))
        self.l5 = L.BatchNormalization(axis = -1)
        self.l6 = L.Conv2D(516,kernel_size=4,strides=2,padding='same',kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev = 0.02))
        self.l7 = L.BatchNormalization(axis = -1)
        self.embedding = EmbeddingCompresssor()
        self.l9 = L.Reshape(target_shape = (1,1,128))
        self.concat= L.Concatenate() 
        self.l11= L.Conv2D(filters = 1024, kernel_size=4,strides=2,padding='same',kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev = 0.02))
        self.l12= L.BatchNormalization(axis = -1)
        self.l13= L.Conv2D(filters = 1, kernel_size=4,strides=2,padding='same',kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev = 0.02))
        

    def call(self,inputs):
        I , E = inputs #
        x  = self.l1(I)
        x  = self.l2(x)
        x  = self.l3(x)
        x  = self.l4(x)
        x  = self.l5(x)
        x  = self.l6(x)
        x  = self.l7(x)

        t  = self.embedding(E)
        t  = self.l9(t)
        t  = tf.tile(t,(1,4,4,1))

        merged_input  = self.concat([t,x])

        y = self.l11(merged_input)
        y =self.l12(y)
        y =L.LeakyReLU()(y)

        y = self.l13(y)

        return tf.squeeze(y)


def KL_loss(y_true, y_pred):
  mean = y_pred[:, :128]
  logsigma = y_pred[:, 128:]
  loss = -logsigma + 0.5*(-1 + K.exp(2.0*logsigma) + K.square(mean))
  loss = K.mean(loss)
  return loss

class Stage1Model(keras.Model):
    def __init__(self):
        super(Stage1Model, self).__init__()
        self.generator = Stage1Generator() 
        self.discriminator= Stage1Discriminator()
        self.generator_optimizer = keras.optimizers.Adam(0.001)
        self.discriminator_optimizer = keras.optimizers.Adam(0.001)
        self.noise_dim = 100 #
        self.c_dim = 128 
        self.loss = {}

    def train(self,train_data,num_epochs=10,steps_per_epoch=125):
        noise_dim = self.noise_dim
        for epoch in range(num_epochs):
            print("Epoch %d/%d:\n |"%(epoch+1,num_epochs),end="")
            start_time = time.time()
            if epoch % 100 == 0:
                K.set_value(self.generator_optimizer.learning_rate, self.generator_optimizer.learning_rate / 2)
                K.set_value(self.discriminator_optimizer.learning_rate, self.discriminator_optimizer.learning_rate / 2)
            
            generator_loss_log = []
            discriminator_loss_log =[]
            steps_per_epoch = steps_per_epoch
            # batch iterator Generator
            batch_iter = iter(train_data)
            steps_per_epoch_batch_len = steps_per_epoch//15

            batch_size = next(batch_iter)[0].shape[0]
            for i in range(steps_per_epoch):
                if steps_per_epoch % steps_per_epoch_batch_len == 0:
                    print("=",end="")
                image_batch, embedding_batch = next(batch_iter)
                noise = tf.random.normal(shape = (batch_size, noise_dim))

                mismatched_images = tf.roll(image_batch ,shift = 1,axis =0 )
                real_labels = tf.random.uniform(shape = (batch_size,),minval = 0.9 , maxval =1)
                fake_labels = tf.random.uniform(shape = (batch_size, ), minval = 0.0, maxval = 0.1)
                mismatched_labels = tf.random.uniform(shape = (batch_size,), minval=0.0 , maxval = 0.1)

                with tf.GradientTape() as generator_tape, tf.GradientTape() as discriminator_tape:
                    fake_images , phi = self.generator([embedding_batch,noise])
                    real_logits = self.discriminator([image_batch,embedding_batch])
                    fake_logits = self.discriminator([fake_images, embedding_batch])
                    mismatched_logits = self.discriminator([mismatched_images, embedding_batch])
                    
                    l_sup = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(real_labels,real_logits))  
                    l_klreg = KL_loss(tf.random.normal((phi.shape[0], phi.shape[1])), phi)
                    generator_loss = l_sup + 2.0*l_klreg

                    l_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(real_labels, real_logits))
                    l_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(fake_labels, fake_logits))
                    l_mismatched = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(mismatched_labels, mismatched_logits))
                    discriminator_loss = 0.5*tf.add(l_real, 0.5*tf.add(l_fake, l_mismatched))
                    
                generator_gradients = generator_tape.gradient(generator_loss,self.generator.trainable_variables)
                discriminator_gradients = discriminator_tape.gradient(discriminator_loss,self.discriminator.trainable_variables)
                self.generator_optimizer.apply_gradients(zip(generator_gradients, self.generator.trainable_variables))
                self.discriminator_optimizer.apply_gradients(zip(discriminator_gradients, self.discriminator.trainable_variables))
                
                generator_loss_log.append(generator_loss)
                discriminator_loss_log.append(discriminator_loss)

            end_time = time.time()

            if epoch % 1 == 0:
                epoch_time = end_time - start_time
                template = " - generator_loss: {:.4f} - discriminator_loss: {:.4f} - epoch_time: {:.2f} s"
                print(template.format(tf.reduce_mean(generator_loss_log), tf.reduce_mean(discriminator_loss_log), epoch_time))

            if (epoch + 1) % 10 == 0 or epoch == num_epochs - 1:
                import os
                save_path = "./lr_results/epoch_" + str(epoch + 1)
                temp_embeddings = None
                for _, embeddings in train_data:
                    temp_embeddings = embeddings.numpy()
                    break
                if os.path.exists(save_path) == False:
                    os.makedirs(save_path)
                
                temp_batch_size = 10
                temp_z_noise = tf.random.normal((temp_batch_size, self.noise_dim))
                temp_embedding_batch = temp_embeddings[0:temp_batch_size]
                fake_images, _ = self.generator([temp_embedding_batch, temp_z_noise])
               
                for i, image in enumerate(fake_images):
                    image = 127.5*image + 127.5
                    image = image.numpy().astype('uint8')
                    image = keras.preprocessing.image.array_to_img(image)
                    image.save(save_path + "/gen_%d.png"%(i))
                    # cv2.imwrite(save_path + "/gen_%d.png"%(i), image)
                
                self.generator.save_weights("./weights/stage1_generator_" + str(epoch + 1) + ".ckpt")
                self.discriminator.save_weights("./weights/stage1_discriminator_" + str(epoch + 1) + ".ckpt")
                
                
model = Stage1Model()

In [21]:
model.train(train_ds,num_epochs=1)

Epoch 1/1:


NameError: name 'os' is not defined