In [None]:
import tensorflow as tf
from tensorflow import keras 
import numpy as np
#from  utils.dataset import Dataset
from re import template
from numpy.lib.type_check import imag
from tensorflow.keras import backend as  K 
from tensorflow import keras
from tensorflow.keras import layers as L
import time
import pickle
import os
from kaggle_datasets import KaggleDatasets
from IPython.display import clear_output
from tensorflow.keras import losses

In [None]:
try:
    # TPUClusterResolver() is expected to raise ValueError when no TPU can be found.
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except ValueError:
    # No TPU: fall back to the default strategy (CPU / single GPU).
    # Catching only ValueError (instead of a bare except) lets real TPU
    # initialisation failures and KeyboardInterrupt surface.
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

AUTOTUNE = tf.data.experimental.AUTOTUNE

print(tf.__version__)

In [None]:
# GCS location of the attached Kaggle dataset
# (inspect the bucket with "!gsutil ls $GCS_DS_PATH").
GCS_DS_PATH = KaggleDatasets().get_gcs_path()
# tf.data autotune sentinel -- same constant as AUTOTUNE.
AUTO = tf.data.experimental.AUTOTUNE

In [None]:
# TensorFlow 2.4+: one detection path for TPU, GPU / multi-GPU and CPU.
try:
    # connect() resolves and initialises an attached TPU; a ValueError is
    # expected when there is none.
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError:
    # No TPU: mirror across GPU(s).
    # Alternatives:
    #   tf.distribute.get_strategy()                             -- default strategy, CPU / single GPU
    #   tf.distribute.experimental.MultiWorkerMirroredStrategy()  -- clusters of multi-GPU machines
    strategy = tf.distribute.MirroredStrategy()

print("Number of accelerators: ", strategy.num_replicas_in_sync)

In [None]:
def load_text_embeddings(path):
    """Unpickle and return the char-CNN-RNN text embeddings stored at ``path``.

    ``encoding='latin1'`` is passed to ``pickle.load``; presumably the file was
    written by Python 2 (this encoding is the usual way to read such pickles).

    NOTE: unpickling can execute arbitrary code -- only load trusted files.
    """
    with open(path, mode="rb") as handle:
        return pickle.load(handle, encoding="latin1")

def load_images(path):
    """Unpickle and return the image collection stored at ``path``.

    No ``encoding`` argument is given, so pickle's default applies.

    NOTE: unpickling can execute arbitrary code -- only load trusted files.
    """
    with open(path, mode="rb") as handle:
        loaded = pickle.load(handle)
    return loaded

def load_data(embeddings_path, pickle_data_file):
    """Load one split and return ``(images, embeddings)`` as numpy arrays.

    The embeddings file is read first, then the image pickle.
    """
    text_embeddings = np.array(load_text_embeddings(embeddings_path))
    image_array = np.array(load_images(pickle_data_file))
    return image_array, text_embeddings

In [None]:
# Dataset layout: one folder per split, each with the caption embeddings and
# two image pickles (64images / 256images).
path = "../input/gandata20/birds"
train_path, test_path = (f"{path}/{split}" for split in ("train", "test"))

embedding_train, embedding_test = (
    f"{split_dir}/char-CNN-RNN-embeddings.pickle" for split_dir in (train_path, test_path)
)

pickle_train_low, pickle_test_low = (f"{split_dir}/64images.pickle" for split_dir in (train_path, test_path))

pickle_train_high, pickle_test_high = (f"{split_dir}/256images.pickle" for split_dir in (train_path, test_path))

In [None]:
# Low-resolution images (the 64images pickles) plus caption embeddings, per split.
x_train_low, train_embeds = load_data(embeddings_path=embedding_train, pickle_data_file=pickle_train_low)
x_test_low, test_embeds = load_data(embeddings_path=embedding_test, pickle_data_file=pickle_test_low)

In [None]:
# Sanity check: shape of the low-resolution training image array.
x_train_low.shape

In [None]:
# Keep only index 0 along axis 1 of the embeddings
# (presumably the first of several captions per image -- confirm against the pickle).
train_embedding = np.take(train_embeds, 0, axis=1)
train_embedding.shape

In [None]:
# Stage-I input pipeline: (image, embedding) pairs, both float32.
# NOTE(review): pixels are used exactly as stored in the pickle; confirm they
# are already in the generator's tanh range [-1, 1].
x_train_low = tf.cast(x_train_low, dtype=tf.float32)
train_embedding = tf.cast(train_embedding, dtype=tf.float32)

train_ds = tf.data.Dataset.from_tensor_slices(tensors=(x_train_low, train_embedding))

In [None]:
# Number of replicas the detected distribution strategy runs in sync.
strategy.num_replicas_in_sync

In [None]:
# Batch size used for Stage I (32 x 4).
BATCH_SIZE = 128
# NOTE: not idempotent -- re-running this cell batches the already-batched dataset again.
train_ds = train_ds.batch(batch_size=BATCH_SIZE)

In [None]:
#train_ds = train_ds.prefetch(1)

In [None]:
import matplotlib.pyplot as plt


def show_image(x):
    """Draw ``x`` on the single axes of a new matplotlib figure.

    ``x`` is passed to ``Axes.imshow`` unchanged, so it must already be in a
    range imshow accepts for RGB data (floats in [0, 1] or ints in [0, 255]).
    Returns None.
    """
    figure = plt.figure()
    axes = figure.add_subplot(111)
    axes.imshow(x)

In [None]:
# dataset = Dataset(image_size=(64,64),data_base_path="./data",batch_size=128)
# train_ds = dataset.get_train_ds()

In [None]:

############################################################
# Conditioning Augmentation Network
############################################################
class ConditioningAugmentation(keras.Model):
    """Conditioning Augmentation (CA) block of StackGAN.

    A Dense(256) + LeakyReLU(0.2) maps the text embedding to ``phi``.
    The first 128 features are used as the mean and ``exp`` of the last 128
    as the standard deviation. A reparameterised sample
    ``c = mean + std * epsilon`` is returned together with ``phi``.
    """

    def __init__(self, *args, **kwargs):
        super(ConditioningAugmentation, self).__init__(*args, **kwargs)
        self.dense = L.Dense(256)
        self.activation = L.LeakyReLU(alpha=0.2)

    def call(self, input):
        """Return ``(c, phi)`` for an embedding batch.

        Args:
            input: embedding batch, assumed shape (batch, embed_dim) -- TODO confirm.
        Returns:
            c: (batch, 128) conditioning sample.
            phi: (batch, 256) features (mean half | log-std half).
        """
        x = self.dense(input)
        phi = self.activation(x)
        mean = phi[:, :128]
        std = tf.math.exp(phi[:, 128:])
        # Independent noise per sample and per dimension. The previous shape
        # (128,) drew one vector and broadcast it across the whole batch.
        epsilon = tf.random.normal(tf.shape(mean), dtype=mean.dtype)
        output = mean + std * epsilon
        return output, phi

class EmbeddingCompresssor(keras.Model):
    """Compress a text embedding to 128 features: Dense(128) -> LeakyReLU(0.2).

    The misspelt class name is kept because the discriminators refer to it.
    """

    def __init__(self):
        super(EmbeddingCompresssor, self).__init__()
        self.dense = L.Dense(128)
        # Built once here instead of constructing a new layer on every call().
        self.activation = L.LeakyReLU(0.2)

    def call(self, input):
        x = self.dense(input)
        return self.activation(x)


############################################################
# Stage 1 Generator Network (CGAN)
############################################################


def UpSamplingBlock(input, num_filters):
    """Double the spatial size of ``input``, then apply Conv3x3 (no bias) -> BatchNorm -> ReLU.

    Every call constructs brand-new layers. Use it only when wiring a model
    once (functional style). Calling it inside a subclassed ``Model.call``
    creates fresh weights on each pass that the model does not track.
    """
    stages = (
        L.UpSampling2D(size=2),
        L.Conv2D(num_filters, kernel_size=3, padding='same', strides=1, use_bias=False),
        L.BatchNormalization(),
        L.ReLU(),
    )
    out = input
    for stage in stages:
        out = stage(out)
    return out


class Stage1Generator(keras.Model):
    """StackGAN Stage-I generator: ``[embedding, noise] -> (image, phi)``.

    Pipeline:
      * The embedding goes through Conditioning Augmentation.
      * The result is concatenated with the noise and projected to a
        4x4x1024 map.
      * Four upsampling blocks take it from 4 to 8, 16, 32 and 64 pixels.
      * A 3x3 conv + tanh produces the image.
    ``phi`` is the CA feature vector, returned so the caller can regularise it.

    All layers are created in ``__init__``. The previous version built the
    upsampling blocks and the output conv inside ``call``. Those weights
    were re-initialised on every forward pass and never appeared in
    ``trainable_variables``, so they were never trained or saved. Weight
    files written by that version do not match this one.
    """

    def __init__(self):
        super(Stage1Generator, self).__init__(name='stage_1_generator')
        self.augmentation = ConditioningAugmentation()
        self.concat = L.Concatenate(axis=1)
        self.dense = tf.keras.layers.Dense(units=128 * 8 * 4 * 4, kernel_initializer=tf.random_normal_initializer(stddev=0.02))
        self.reshape = tf.keras.layers.Reshape(target_shape=(4, 4, 128 * 8), input_shape=(128 * 8 * 4 * 4,))
        self.batchnorm1 = tf.keras.layers.BatchNormalization(axis=-1, momentum=0.99)
        self.activation = L.ReLU()
        # Same filter counts as before: 512, 256, 128, 3.
        self.upsampling_blocks = [self._upsampling_block(n) for n in (512, 256, 128, 3)]
        self.output_conv = L.Conv2D(3, kernel_size=3, padding='same')
        self.output_activation = L.Activation('tanh')

    @staticmethod
    def _upsampling_block(num_filters):
        """UpSampling2D(2) -> Conv3x3 (no bias) -> BatchNorm -> ReLU, packaged as one tracked layer."""
        return keras.Sequential([
            L.UpSampling2D(size=2),
            L.Conv2D(num_filters, kernel_size=3, padding='same', strides=1, use_bias=False),
            L.BatchNormalization(),
            L.ReLU(),
        ])

    def call(self, inputs):
        embedding, noise = inputs
        c, phi = self.augmentation(embedding)
        x = self.concat([c, noise])
        x = self.dense(x)
        x = self.reshape(x)
        x = self.batchnorm1(x)
        x = self.activation(x)
        for block in self.upsampling_blocks:
            x = block(x)
        x = self.output_conv(x)
        x = self.output_activation(x)
        return x, phi

class Stage1Discriminator(keras.Model):
    """StackGAN Stage-I discriminator: ``[image, embedding] -> probability`` per sample.

    Pipeline:
      * Four stride-2 4x4 convs encode the image down to 4x4 (so a 64x64
        input is expected).
      * Each conv is followed by LeakyReLU(0.2), with BatchNorm on all but
        the first.
      * The compressed embedding is tiled to 4x4 and concatenated.
      * Two more stride-2 convs reduce the result to one sigmoid output,
        returned with shape ``(batch,)``.

    Changes vs. the previous version:
      * The encoder convs had no non-linearity between them, so the image
        path before the concat was affine.
      * Output is reshaped to ``(batch,)``. ``tf.squeeze`` turned a batch
        of one into a scalar.
      * Activation layers are created once in ``__init__``.

    Attribute names ``l1``..``l13`` and the filter counts are unchanged so
    existing weight files should still load. This includes 516 on ``l6``,
    which looks like a typo for 512. The new layers carry no weights.
    """

    def __init__(self, *args, **kwargs):
        super(Stage1Discriminator, self).__init__(*args, **kwargs)
        self.l1 = L.Conv2D(64, kernel_size=4, strides=2, padding='same', kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
        self.l2 = L.Conv2D(128, kernel_size=4, strides=2, padding='same', kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
        self.l3 = L.BatchNormalization(axis=-1)
        self.l4 = L.Conv2D(256, kernel_size=4, strides=2, padding='same', kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
        self.l5 = L.BatchNormalization(axis=-1)
        self.l6 = L.Conv2D(516, kernel_size=4, strides=2, padding='same', kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
        self.l7 = L.BatchNormalization(axis=-1)
        self.embedding = EmbeddingCompresssor()
        self.l9 = L.Reshape(target_shape=(1, 1, 128))
        self.concat = L.Concatenate()
        self.l11 = L.Conv2D(filters=1024, kernel_size=4, strides=2, padding='same', kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
        self.l12 = L.BatchNormalization(axis=-1)
        self.l13 = L.Conv2D(filters=1, kernel_size=4, strides=2, padding='same', kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
        # Stateless activations, built once and reused.
        self.encoder_activation = L.LeakyReLU(alpha=0.2)
        self.joint_activation = L.LeakyReLU()
        self.output_activation = L.Activation('sigmoid')

    def call(self, inputs):
        I, E = inputs
        x = self.encoder_activation(self.l1(I))
        x = self.encoder_activation(self.l3(self.l2(x)))
        x = self.encoder_activation(self.l5(self.l4(x)))
        x = self.encoder_activation(self.l7(self.l6(x)))

        t = self.embedding(E)
        t = self.l9(t)
        t = tf.tile(t, (1, 4, 4, 1))

        merged_input = self.concat([t, x])

        y = self.l11(merged_input)
        y = self.l12(y)
        y = self.joint_activation(y)

        y = self.l13(y)
        y = self.output_activation(y)
        # (batch, 1, 1, 1) -> (batch,), also for a batch of one.
        return tf.reshape(y, [-1])


def KL_loss(y_true, y_pred):
    """Return the mean Keras ``KLDivergence`` between ``y_true`` and ``y_pred``.

    NOTE(review): the training loops call this with a random-normal
    ``y_true`` and the conditioning features ``phi`` as ``y_pred``. That is
    not the closed-form KL(N(mean, std) || N(0, I)) regulariser from
    StackGAN. Confirm this is the intended term.
    """
    divergence = tf.keras.losses.KLDivergence()(y_true, y_pred)
    return tf.reduce_mean(divergence)

def loss(true_label, predicted_label):
    """Adversarial loss: mean binary cross-entropy between labels and D outputs.

    ``predicted_label`` must be a probability. Stage1Discriminator ends in a
    sigmoid, so no logits are assumed. Both arguments are flattened, which
    also tolerates a scalar prediction.

    This previously returned a KL divergence. For the soft "fake" targets
    (0.0-0.1) the term y*log(y/p) shrinks as p grows, so the discriminator
    was rewarded for scoring fake and mismatched pairs as real.
    """
    labels = tf.reshape(tf.cast(true_label, tf.float32), [-1])
    probabilities = tf.reshape(tf.cast(predicted_label, tf.float32), [-1])
    return tf.reduce_mean(tf.keras.losses.binary_crossentropy(labels, probabilities))
    
def predict(x, true):
    """Return 1 if score ``x`` is on the side of 0.5 that target ``true`` names, else 0.

    ``true == 1`` is correct when ``x > 0.5``; ``true == 0`` when ``x < 0.5``.
    Any other target, or ``x == 0.5`` exactly, scores 0.
    """
    hit = (true == 1 and x > 0.5) or (true == 0 and x < 0.5)
    return 1 if hit else 0
        
def accuracy(y_true, y_pred):
    """Return the fraction of predictions on the correct side of 0.5.

    The target for the whole batch comes from ``y_true[0]``:
      * above 0.5 ("real"): predictions ``> 0.5`` count as correct;
      * otherwise: predictions ``< 0.5`` count as correct.
    A score of exactly 0.5 is never correct, the same rule as ``predict``.

    This is a vectorised replacement for the former ``tf.map_fn`` over a
    Python callback, which made one eager call per element. It also accepts
    a scalar ``y_pred`` and returns a scalar in ``y_pred``'s dtype.
    """
    scores = tf.reshape(tf.convert_to_tensor(y_pred), [-1])
    if y_true[0] > 0.5:
        correct = scores > 0.5
    else:
        correct = scores < 0.5
    return tf.reduce_mean(tf.cast(correct, scores.dtype))

class Stage1Model(tf.keras.Model):
    """Trains the StackGAN Stage-I generator and discriminator with a custom loop.

    Each network has its own Adam optimizer (beta_1=0.5, beta_2=0.999).
    ``train`` writes sample PNGs to ``./lr_results/epoch_N/`` and weights to
    ``./weights/weights_N/``.
    """

    def __init__(self, lr_g=0.0001, lr_d=0.0001):
        super(Stage1Model, self).__init__()
        self.generator = Stage1Generator()
        self.discriminator = Stage1Discriminator()
        self.generator_optimizer = keras.optimizers.Adam(learning_rate=lr_g, beta_1=0.5, beta_2=0.999)
        self.discriminator_optimizer = keras.optimizers.Adam(learning_rate=lr_d, beta_1=0.5, beta_2=0.999)
        self.noise_dim = 100
        self.c_dim = 128
        self.loss = {}

    def load_weights(self, path):
        """Restore both networks from ``path + "/stage1_generator.h5"`` and ``"/stage1_discriminator.h5"``.

        Subclassed models have no variables until first called, so both are run
        once on random inputs before loading. The embedding size 1024 is assumed
        to match the char-CNN-RNN embeddings.
        """
        z_noise = tf.random.normal((1, self.noise_dim))
        embedding = tf.random.normal((1, 1024))
        image, _ = self.generator([embedding, z_noise])
        self.discriminator([image, embedding])
        self.generator.load_weights(path + "/stage1_generator.h5")
        self.discriminator.load_weights(path + "/stage1_discriminator.h5")

    def train(self, train_ds, batch_size=64, num_epochs=600, save_weights_epoch=5, train_length=8855):
        """Run the adversarial training loop.

        Args:
            train_ds: batched ``tf.data.Dataset`` of ``(image_batch, embedding_batch)``.
            batch_size: batch size ``train_ds`` was batched with. It is only used
                to derive ``steps_per_epoch = train_length // batch_size``.
            num_epochs: number of epochs to run.
            save_weights_epoch: save samples and weights every this many epochs,
                and after the last epoch.
            train_length: number of training examples.

        Printed accuracies are the mean over the epoch's steps.
        """
        # Computed once. The old loop overwrote batch_size with the size of the
        # last batch, so later epochs ran a different number of steps.
        steps_per_epoch = train_length // batch_size
        for epoch in range(num_epochs):
            print("Epoch %d/%d:\n " % (epoch + 1, num_epochs), end="")
            start_time = time.time()
            # Halve both learning rates every 100 epochs. Previously this fired at
            # epoch 0 and halved only the generator's rate, twice.
            if epoch > 0 and epoch % 100 == 0:
                self._decay_learning_rates()

            generator_loss_log = []
            discriminator_loss_log = []
            generator_acc_log = []
            discriminator_acc_log = []
            # take() ends the epoch cleanly even if train_ds has fewer batches.
            for step, (image_batch, embedding_batch) in enumerate(train_ds.take(steps_per_epoch)):
                if step % 5 == 0:
                    print("=", end="")
                g_loss, d_loss, g_acc, d_acc = self._train_step(image_batch, embedding_batch)
                generator_loss_log.append(g_loss)
                discriminator_loss_log.append(d_loss)
                generator_acc_log.append(g_acc)
                discriminator_acc_log.append(d_acc)

            epoch_time = time.time() - start_time
            template = " - generator_loss: {:.4f}- generator accuracy: {:.4f} - discriminator_loss: {:.4f} - discriminator accuracy: {:.4f} - epoch_time: {:.2f} "
            print(template.format(
                float(tf.reduce_mean(generator_loss_log)),
                float(tf.reduce_mean(generator_acc_log)),
                float(tf.reduce_mean(discriminator_loss_log)),
                float(tf.reduce_mean(discriminator_acc_log)),
                epoch_time,
            ))

            if (epoch + 1) % save_weights_epoch == 0 or epoch == num_epochs - 1:
                self._save_samples(train_ds, "./lr_results/epoch_" + str(epoch + 1))
                self._save_weights(f"./weights/weights_{epoch+1}")
                clear_output(wait=True)

    def _decay_learning_rates(self):
        """Halve the generator and discriminator learning rates in place."""
        for optimizer in (self.generator_optimizer, self.discriminator_optimizer):
            K.set_value(optimizer.learning_rate, optimizer.learning_rate / 2)

    def _train_step(self, image_batch, embedding_batch):
        """Run one generator and one discriminator update.

        Returns ``(generator_loss, discriminator_loss, g_acc, d_acc)``.
        """
        batch_size = image_batch.shape[0]
        z_noise = tf.random.normal((batch_size, self.noise_dim))
        # Shift images by one so each is paired with another sample's embedding.
        mismatched_images = tf.roll(image_batch, shift=1, axis=0)

        # Soft labels.
        real_labels = tf.random.uniform(shape=(batch_size,), minval=0.9, maxval=1.0)
        fake_labels = tf.random.uniform(shape=(batch_size,), minval=0.0, maxval=0.1)
        mismatched_labels = tf.random.uniform(shape=(batch_size,), minval=0.0, maxval=0.1)

        with tf.GradientTape() as generator_tape, tf.GradientTape() as discriminator_tape:
            # training=True so BatchNorm uses batch statistics during training.
            fake_images, _ = self.generator([embedding_batch, z_noise], training=True)
            real_logits = self.discriminator([image_batch, embedding_batch], training=True)
            fake_logits = self.discriminator([fake_images, embedding_batch], training=True)
            mismatched_logits = self.discriminator([mismatched_images, embedding_batch], training=True)

            # As before, the KL regulariser is not part of the generator loss.
            generator_loss = loss(real_labels, fake_logits)

            l_real = loss(real_labels, real_logits)
            l_fake = loss(fake_labels, fake_logits)
            l_mismatched = loss(mismatched_labels, mismatched_logits)
            discriminator_loss = 0.5 * tf.add(l_real, 0.5 * tf.add(l_fake, l_mismatched))

        generator_gradients = generator_tape.gradient(generator_loss, self.generator.trainable_variables)
        discriminator_gradients = discriminator_tape.gradient(discriminator_loss, self.discriminator.trainable_variables)
        self.generator_optimizer.apply_gradients(zip(generator_gradients, self.generator.trainable_variables))
        self.discriminator_optimizer.apply_gradients(zip(discriminator_gradients, self.discriminator.trainable_variables))

        # Per-step accuracies in [0, 1]. The old code divided these by the batch
        # size and summed them over steps.
        g_acc = accuracy(real_labels, fake_logits)
        d_acc = (accuracy(real_labels, real_logits)
                 + accuracy(fake_labels, fake_logits)
                 + accuracy(mismatched_labels, mismatched_logits)) / 3
        return generator_loss, discriminator_loss, g_acc, d_acc

    def _save_samples(self, train_ds, save_path, num_samples=10):
        """Save PNGs generated from up to ``num_samples`` embeddings of the first batch."""
        os.makedirs(save_path, exist_ok=True)
        _, embeddings = next(iter(train_ds))
        sample_embeddings = embeddings[:num_samples]
        # Match the noise batch to the embeddings actually available.
        z_noise = tf.random.normal((sample_embeddings.shape[0], self.noise_dim))
        fake_images, _ = self.generator([sample_embeddings, z_noise], training=False)
        for index, image in enumerate(fake_images):
            # tanh output [-1, 1] -> [0, 255]; clip before the uint8 cast.
            pixels = tf.clip_by_value(127.5 * image + 127.5, 0.0, 255.0).numpy().astype('uint8')
            keras.preprocessing.image.array_to_img(pixels).save(save_path + "/gen_%d.png" % index)

    def _save_weights(self, weights_path):
        """Write both networks' weights as HDF5 files under ``weights_path``."""
        os.makedirs(weights_path, exist_ok=True)
        self.generator.save_weights(weights_path + "/stage1_generator.h5")
        self.discriminator.save_weights(weights_path + "/stage1_discriminator.h5")

    def generate_image(self, embedding, batch_size=64):
        """Generate images for ``embedding`` with fresh noise.

        ``embedding`` needs a leading batch dimension equal to ``batch_size``.
        Returns the generator's ``(images, phi)`` tuple.
        """
        z_noise = tf.random.normal((batch_size, self.noise_dim))
        generated_image = self.generator([embedding, z_noise])
        return generated_image


# generate_image used to be mis-indented to module level; keep that name
# working alongside the method.
generate_image = Stage1Model.generate_image


In [None]:
model = Stage1Model(lr_g=0.0001, lr_d=0.0001)
#model.load_weights(path="./weights/weights_20")
# Pass the batch size train_ds was actually batched with and the real
# training-set size. With batch_size=256 against a 128-batched dataset,
# each epoch covered only about half the data.
model.train(train_ds, batch_size=BATCH_SIZE, num_epochs=1, save_weights_epoch=20,
            train_length=int(x_train_low.shape[0]))

In [None]:
# Continue Stage-I training; batch size and length match how train_ds was built.
model.train(train_ds, batch_size=BATCH_SIZE, num_epochs=10, save_weights_epoch=20,
            train_length=int(x_train_low.shape[0]))

In [None]:
# generate_image returns the generator's (images, phi) tuple. Give the
# embedding a batch axis, and map the tanh output from [-1, 1] to [0, 1]
# so imshow renders it correctly.
sample_images, _ = generate_image(model, train_embedding[:1], batch_size=1)
show_image((sample_images[0] + 1.0) / 2.0)

In [None]:
# Layer-by-layer summary of the Stage-I generator.
model.generator.summary()

In [None]:

# Stage-II data: images from the 256images training pickle, paired with the
# same index-0 caption embeddings used for Stage I. The old cell relied on a
# `Dataset` helper whose import is commented out at the top of the notebook,
# and it read the test split with batch size 1.
HR_BATCH_SIZE = 4
x_train_high = np.array(load_images(pickle_train_high))
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train_high, train_embeds[:, 0]))
    .map(lambda image, embedding: (tf.cast(image, tf.float32), tf.cast(embedding, tf.float32)),
         num_parallel_calls=AUTOTUNE)
    .batch(HR_BATCH_SIZE)
)

In [None]:
# Number of test images. The old `dataset.test_filenames` needed the undefined `Dataset` helper.
len(x_test_low)

In [None]:
def ResidualBlock(input, num_filters):
    """Residual unit: Conv3x3-BN-ReLU-Conv3x3-BN, plus a shortcut, then ReLU.

    The shortcut is ``input`` itself when it already has ``num_filters``
    channels; otherwise a 1x1 conv projects it so the addition is valid.
    The previous version had no shortcut despite its name.

    Builds new layers on every call, so use it only when wiring a model once
    (functional style), not inside a subclassed ``Model.call``.
    """
    x = L.Conv2D(filters=num_filters, kernel_size=3, strides=1, padding='same')(input)
    x = L.BatchNormalization()(x)
    x = L.ReLU()(x)
    x = L.Conv2D(filters=num_filters, kernel_size=3, strides=1, padding='same')(x)
    x = L.BatchNormalization()(x)
    shortcut = input
    if input.shape[-1] != num_filters:
        shortcut = L.Conv2D(filters=num_filters, kernel_size=1, strides=1, padding='same')(input)
    x = L.Add()([x, shortcut])
    return L.ReLU()(x)


class _Stage2ResidualUnit(L.Layer):
    """Conv3x3-BN-ReLU-Conv3x3-BN plus a shortcut, followed by ReLU.

    The shortcut is the input itself when it already has ``num_filters``
    channels; otherwise a 1x1 conv (created in ``build``) projects it.
    """

    def __init__(self, num_filters, **kwargs):
        super(_Stage2ResidualUnit, self).__init__(**kwargs)
        self.num_filters = num_filters
        self.conv1 = L.Conv2D(filters=num_filters, kernel_size=3, strides=1, padding='same')
        self.bn1 = L.BatchNormalization()
        self.conv2 = L.Conv2D(filters=num_filters, kernel_size=3, strides=1, padding='same')
        self.bn2 = L.BatchNormalization()
        self.projection = None

    def build(self, input_shape):
        if input_shape[-1] != self.num_filters:
            self.projection = L.Conv2D(filters=self.num_filters, kernel_size=1, strides=1, padding='same')
        super(_Stage2ResidualUnit, self).build(input_shape)

    def call(self, inputs):
        x = tf.nn.relu(self.bn1(self.conv1(inputs)))
        x = self.bn2(self.conv2(x))
        shortcut = inputs if self.projection is None else self.projection(inputs)
        return tf.nn.relu(x + shortcut)


class Stage2Generator(keras.Model):
    """Stage-II generator: refines an image, upsampling it 4x in each spatial dimension.

    The block order is unchanged: residual(128), residual(256),
    upsample(256), residual(128), upsample(256), residual(3). Each block is
    now built once in ``__init__``. The old ``call`` created new layers on
    every forward pass, so none of their weights were tracked, trained or
    saved, and its "residual" blocks had no shortcut.

    NOTE(review): the last unit ends in ReLU, so outputs are non-negative.
    The sample-saving code maps outputs as if they were in [-1, 1] (tanh).
    Confirm whether a tanh output layer was intended.
    """

    def __init__(self, *args, **kwargs):
        super(Stage2Generator, self).__init__(*args, **kwargs)
        self.refinement_blocks = [
            _Stage2ResidualUnit(128),
            _Stage2ResidualUnit(256),
            self._upsampling_block(256),
            _Stage2ResidualUnit(128),
            self._upsampling_block(256),
            _Stage2ResidualUnit(3),
        ]

    @staticmethod
    def _upsampling_block(num_filters):
        """UpSampling2D(2) -> Conv3x3 (no bias) -> BatchNorm -> ReLU, packaged as one tracked layer."""
        return keras.Sequential([
            L.UpSampling2D(size=2),
            L.Conv2D(num_filters, kernel_size=3, padding='same', strides=1, use_bias=False),
            L.BatchNormalization(),
            L.ReLU(),
        ])

    def call(self, inputs):
        x = inputs
        for block in self.refinement_blocks:
            x = block(x)
        return x

def DownSamplingBlock(inputs,
                      num_filters,
                      kernel_size=4,
                      strides=2,
                      batch_norm=True,
                      activation=True):
    """Conv2D ('same' padding), optionally followed by BatchNorm and LeakyReLU.

    Builds new layers on every call, so use it only when wiring a model once
    (functional style).
    """
    stack = [L.Conv2D(filters=num_filters, kernel_size=kernel_size, strides=strides, padding='same')]
    if batch_norm:
        stack.append(L.BatchNormalization())
    if activation:
        stack.append(L.LeakyReLU())
    out = inputs
    for layer in stack:
        out = layer(out)
    return out


class Stage2Discriminator(tf.keras.Model):
    """StackGAN Stage-II discriminator: ``[image, embedding] -> logits`` of shape (batch,).

    Pipeline:
      * Six stride-2 convs reduce the image to 4x4, so a 256x256 input is
        expected. This matches the embedding, compressed to 128 features and
        tiled to 4x4.
      * A 1x1 / 3x3 / 3x3 branch is added back onto the encoder output.
      * The result is concatenated with the embedding.
      * A 4x4 'valid' conv gives one raw logit per sample. There is no
        sigmoid; Stage2Model uses ``sigmoid_cross_entropy_with_logits``.

    All layers are now created in ``__init__``. The previous ``call``
    rebuilt them on every pass, so they were untracked and never trained.
    It also printed debug output and ended with an incomplete ``z =``
    statement, which is a SyntaxError.
    """

    def __init__(self):
        super(Stage2Discriminator, self).__init__()
        self.embed = EmbeddingCompresssor()
        self.reshape = tf.keras.layers.Reshape(target_shape=(1, 1, 128))
        self.conv_out = tf.keras.layers.Conv2D(filters=1, kernel_size=4, strides=1, padding="valid", kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
        self.encoder = [
            self._down_block(64, batch_norm=False),
            self._down_block(128),
            self._down_block(256),
            self._down_block(512),
            self._down_block(1024),
            self._down_block(512),
            self._down_block(128, kernel_size=1, strides=1),
        ]
        self.refine = [
            self._down_block(128, kernel_size=1, strides=1),
            self._down_block(256, kernel_size=3, strides=1),
            self._down_block(128, kernel_size=3, strides=1),
        ]
        self.residual_add = tf.keras.layers.Add()
        self.concat = tf.keras.layers.Concatenate()
        self.joint = self._down_block(128, kernel_size=1, strides=1)

    @staticmethod
    def _down_block(num_filters, kernel_size=4, strides=2, batch_norm=True):
        """Conv2D ('same') -> [BatchNorm] -> LeakyReLU, packaged as one tracked layer."""
        stack = [L.Conv2D(filters=num_filters, kernel_size=kernel_size, strides=strides, padding='same')]
        if batch_norm:
            stack.append(L.BatchNormalization())
        stack.append(L.LeakyReLU())
        return keras.Sequential(stack)

    def call(self, inputs):
        I, E = inputs
        T = self.reshape(self.embed(E))
        T = tf.tile(T, (1, 4, 4, 1))

        x = I
        for block in self.encoder:
            x = block(x)

        y = x
        for block in self.refine:
            y = block(y)

        A = tf.nn.leaky_relu(self.residual_add([x, y]))
        merged_input = self.concat([A, T])

        z = self.joint(merged_input)
        z = self.conv_out(z)
        # (batch, 1, 1, 1) -> (batch,) so logits line up with the label vectors.
        return tf.reshape(z, [-1])


class Stage2Model(keras.Model):
    """Trains the StackGAN Stage-II generator and discriminator on top of Stage I.

    Args:
        stage1_generator: optional, ideally pre-trained, ``Stage1Generator``
            mapping ``[embedding, noise]`` to low-resolution images. It is run
            in inference mode only and never updated here. Defaults to a fresh,
            untrained ``Stage1Generator`` (the previous behaviour).
    """

    def __init__(self, stage1_generator=None):
        super(Stage2Model, self).__init__()
        self.generator1 = stage1_generator if stage1_generator is not None else Stage1Generator()
        self.generator2 = Stage2Generator()
        self.discriminator2 = Stage2Discriminator()
        self.generator2_optimizer = keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.5, beta_2=0.999)
        self.discriminator2_optimizer = keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.5, beta_2=0.999)
        self.noise_dim = 100

    def train(self, train_ds, batch_size=64, num_epochs=1, steps_per_epoch=125):
        """Run the Stage-II adversarial loop.

        Args:
            train_ds: batched ``tf.data.Dataset`` of ``(hr_image_batch, embedding_batch)``.
            batch_size: unused, kept for call compatibility. The batch size is
                read from each batch.
            num_epochs: number of epochs to run.
            steps_per_epoch: maximum number of batches per epoch.

        Losses are logged once per epoch. Samples and weights are saved every
        10 epochs and after the last one. Previously both ran inside the step
        loop, i.e. after every step.
        """
        for epoch in range(num_epochs):
            print("Epoch %d/%d:\n  " % (epoch + 1, num_epochs), end="")
            start_time = time.time()
            # Halve both learning rates every 100 epochs (no longer at epoch 0).
            if epoch > 0 and epoch % 100 == 0:
                self._decay_learning_rates()

            generator_loss_log = []
            discriminator_loss_log = []
            for step, (hr_image_batch, embedding_batch) in enumerate(train_ds.take(steps_per_epoch)):
                if step % 5 == 0:
                    print("=", end="")
                generator_loss, discriminator_loss = self._train_step(hr_image_batch, embedding_batch)
                generator_loss_log.append(generator_loss)
                discriminator_loss_log.append(discriminator_loss)

            epoch_time = time.time() - start_time
            template = " - generator_loss: {:.4f} - discriminator_loss: {:.4f} - epoch_time: {:.2f} s"
            print(template.format(float(tf.reduce_mean(generator_loss_log)),
                                  float(tf.reduce_mean(discriminator_loss_log)),
                                  epoch_time))

            if (epoch + 1) % 10 == 0 or epoch == num_epochs - 1:
                self._save_samples(train_ds, "./hr_results/epoch_" + str(epoch + 1))
                self._save_weights(f"./weights/hr_weights_{epoch+1}")

    def _decay_learning_rates(self):
        """Halve the Stage-II generator and discriminator learning rates in place."""
        for optimizer in (self.generator2_optimizer, self.discriminator2_optimizer):
            K.set_value(optimizer.learning_rate, optimizer.learning_rate / 2)

    def _train_step(self, hr_image_batch, embedding_batch):
        """Run one Stage-II generator and discriminator update; return ``(g_loss, d_loss)``."""
        batch_size = hr_image_batch.shape[0]
        z_noise = tf.random.normal((batch_size, self.noise_dim))
        # Shift images by one so each is paired with another sample's embedding.
        mismatched_images = tf.roll(hr_image_batch, shift=1, axis=0)

        real_labels = tf.random.uniform(shape=(batch_size,), minval=0.9, maxval=1.0)
        fake_labels = tf.random.uniform(shape=(batch_size,), minval=0.0, maxval=0.1)
        mismatched_labels = tf.random.uniform(shape=(batch_size,), minval=0.0, maxval=0.1)

        with tf.GradientTape() as generator_tape, tf.GradientTape() as discriminator_tape:
            # The Stage-I generator is frozen: inference mode, excluded from updates.
            lr_fake_images, phi = self.generator1([embedding_batch, z_noise], training=False)
            hr_fake_images = self.generator2(lr_fake_images, training=True)
            real_logits = self.discriminator2([hr_image_batch, embedding_batch], training=True)
            fake_logits = self.discriminator2([hr_fake_images, embedding_batch], training=True)
            mismatched_logits = self.discriminator2([mismatched_images, embedding_batch], training=True)

            # The generator wants D to call its fakes real. This previously
            # scored real_logits, which gave generator2 no gradient.
            l_sup = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=real_labels, logits=fake_logits))
            # phi comes from the frozen Stage-I generator, so this term does not
            # affect generator2's gradients; it is kept in the logged loss as before.
            l_klreg = KL_loss(tf.random.normal((phi.shape[0], phi.shape[1])), phi)
            generator_loss = l_sup + 2.0 * l_klreg

            # labels/logits are passed by keyword; they were swapped for the fake
            # and mismatched terms. Weighting now matches Stage I.
            l_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=real_labels, logits=real_logits))
            l_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=fake_labels, logits=fake_logits))
            l_mismatched = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=mismatched_labels, logits=mismatched_logits))
            discriminator_loss = 0.5 * tf.add(l_real, 0.5 * tf.add(l_fake, l_mismatched))

        generator_gradients = generator_tape.gradient(generator_loss, self.generator2.trainable_variables)
        self.generator2_optimizer.apply_gradients(zip(generator_gradients, self.generator2.trainable_variables))

        discriminator_gradients = discriminator_tape.gradient(discriminator_loss, self.discriminator2.trainable_variables)
        self.discriminator2_optimizer.apply_gradients(zip(discriminator_gradients, self.discriminator2.trainable_variables))
        return generator_loss, discriminator_loss

    def _save_samples(self, train_ds, save_path, num_samples=10):
        """Save Stage-II PNGs for up to ``num_samples`` embeddings of the first batch."""
        os.makedirs(save_path, exist_ok=True)
        _, embeddings = next(iter(train_ds))
        sample_embeddings = embeddings[:num_samples]
        z_noise = tf.random.normal((sample_embeddings.shape[0], self.noise_dim))
        # generator2 refines Stage-I images; it does not take [embedding, noise].
        lr_images, _ = self.generator1([sample_embeddings, z_noise], training=False)
        fake_images = self.generator2(lr_images, training=False)
        for index, image in enumerate(fake_images):
            pixels = tf.clip_by_value(127.5 * image + 127.5, 0.0, 255.0).numpy().astype('uint8')
            keras.preprocessing.image.array_to_img(pixels).save(save_path + "/gen_%d.png" % index)

    def _save_weights(self, weights_path):
        """Write the Stage-II networks' weights under ``weights_path`` (was ``os.mkdirs``, which does not exist)."""
        os.makedirs(weights_path, exist_ok=True)
        self.generator2.save_weights(weights_path + "/stage2_generator.h5")
        self.discriminator2.save_weights(weights_path + "/stage2_discriminator.h5")


# Stage II refines the output of the Stage-I generator trained above.
model2 = Stage2Model(stage1_generator=model.generator)



In [None]:
# Train Stage II. batch_size is not used by train(); the real batch size comes from each batch of train_ds.
model2.train(train_ds,batch_size=4)