In [1]:
import tensorflow as tf
from tensorflow.keras import backend as K
import numpy as np
import os
from matplotlib import pyplot as plt
import pickle
import cv2
import time

In [2]:
def KL_loss(y_true, y_pred):
  """Regulariser computed from a packed (mean, log-sigma) tensor.

  `y_pred` is sliced on its second axis: columns [0, 128) are read as the
  mean and columns [128, end) as log-sigma. `y_true` is never read.

  Returns the mean over every element of
  -log_sigma + 0.5 * (-1 + exp(2 * log_sigma) + mean ** 2).
  """
  mu, log_sigma = y_pred[:, :128], y_pred[:, 128:]
  variance = K.exp(2.0*log_sigma)
  per_element = -log_sigma + 0.5*(-1 + variance + K.square(mu))
  return K.mean(per_element)

# Data Pipeline

In [3]:
# Passed as `num_parallel_calls` to Dataset.map so tf.data chooses the parallelism itself.
WORKERS = tf.data.experimental.AUTOTUNE

In [4]:
class Dataset():
  """tf.data pipelines yielding (cropped image, caption embedding) pairs.

  File names and embeddings are read from pickles under `data_dir`; bounding
  boxes and the id -> relative-path table are read from text files under
  `cub_dataset_dir`. Images are cropped to their bounding box, resized to
  `image_size` and scaled to [-1, 1].
  """

  def __init__(self, image_size, batch_size = 64, data_dir = "./data/birds", cub_dataset_dir = "./data/CUB_200_2011"):
    """Load all metadata eagerly (no image is decoded here).

    Args:
      image_size: (width, height) of the images the pipelines emit.
      batch_size: number of samples per emitted batch.
      data_dir: directory holding `train/` and `test/` pickle files.
      cub_dataset_dir: directory holding `bounding_boxes.txt`, `images.txt`
        and the `images/` folder.

    Raises:
      KeyError: if a listed file name has no bounding-box entry.
    """
    train_dir = data_dir + "/train"
    test_dir = data_dir + "/test"
    images_path = cub_dataset_dir + "/images"

    self.image_width = image_size[0]
    self.image_height = image_size[1]
    self.batch_size = batch_size

    self.train_filenames = self._load_filenames(train_dir + "/filenames.pickle", images_path)
    self.test_filenames = self._load_filenames(test_dir + "/filenames.pickle", images_path)
    self.train_embeddings = self._load_pickle(train_dir + "/char-CNN-RNN-embeddings.pickle")
    self.test_embeddings = self._load_pickle(test_dir + "/char-CNN-RNN-embeddings.pickle")

    # image id -> [x, y, width, height]; the file stores floats, truncated to int here.
    bounding_boxes = {
      fields[0]: [int(float(c)) for c in fields[1:]]
      for fields in self._read_rows(cub_dataset_dir + "/bounding_boxes.txt")
    }
    # image id -> path relative to `images_path`
    image_ids_mapping = {
      fields[0]: fields[1]
      for fields in self._read_rows(cub_dataset_dir + "/images.txt")
    }
    # Re-key the boxes by the same full path strings used in the filename lists.
    bounding_boxes_mapping = {
      images_path + "/" + image_ids_mapping[image_id]: box
      for image_id, box in bounding_boxes.items()
    }

    self.train_bounding_boxes = [bounding_boxes_mapping[name] for name in self.train_filenames]
    self.test_bounding_boxes = [bounding_boxes_mapping[name] for name in self.test_filenames]

  @staticmethod
  def _load_pickle(path):
    """Unpickle `path` using the latin1 encoding the data files were written with."""
    # SECURITY: pickle.load can execute arbitrary code; only point this at trusted files.
    with open(path, 'rb') as f:
      return pickle.load(f, encoding = 'latin1')

  @staticmethod
  def _load_filenames(path, images_path):
    """Unpickle a list of relative image names and expand each to a .jpg path."""
    return [images_path + "/" + filename + ".jpg" for filename in Dataset._load_pickle(path)]

  @staticmethod
  def _read_rows(path):
    """Return the whitespace-split fields of every non-blank line in `path`."""
    with open(path, 'rb') as f:
      raw_lines = f.read().splitlines()
    rows = (raw_line.decode('utf-8').split() for raw_line in raw_lines)
    # Blank lines (e.g. a trailing newline pair) would otherwise raise IndexError.
    return [fields for fields in rows if fields]

  def crop(self, image, bounding_box):
    """Crop `image` to `bounding_box` and resize to the configured size.

    Runs eagerly (it is wrapped in tf.py_function), so `image` must expose
    `.numpy()`. When `bounding_box` is None the image is returned as-is,
    without resizing.
    """
    image = image.numpy()
    if bounding_box is not None:
      x, y, width, height = [int(v) for v in bounding_box]
      image = image[y:(y + height), x:(x + width)]
      image = cv2.resize(image, (self.image_width, self.image_height))
    return image

  def parse_function(self, image_path, embeddings, bounding_box):
    """Map one (path, embeddings, box) record to (image, single embedding)."""
    image = tf.io.read_file(image_path)
    image = tf.image.decode_jpeg(image, channels = 3)
    image = tf.py_function(func = self.crop, inp = [image, bounding_box], Tout = tf.float32)
    # cv2.resize takes (width, height) but returns an array shaped (height, width, 3).
    image.set_shape([self.image_height, self.image_width, 3])
    image = (image - 127.5) / 127.5

    # The index must be drawn with a TF op: Dataset.map traces this function
    # once, so a NumPy draw would be frozen into the graph as a constant.
    # maxval is exclusive, so every row (including the last) can be picked.
    embedding_index = tf.random.uniform([], minval = 0, maxval = tf.shape(embeddings)[0], dtype = tf.int32)
    embedding = tf.gather(embeddings, embedding_index)
    return image, embedding

  def _build_ds(self, filenames, embeddings, bounding_boxes, repeat_count):
    """Shared shuffle -> repeat -> map -> batch -> prefetch pipeline."""
    ds = tf.data.Dataset.from_tensor_slices((filenames, embeddings, bounding_boxes))
    ds = ds.shuffle(len(filenames))
    ds = ds.repeat(repeat_count)
    ds = ds.map(self.parse_function, num_parallel_calls = WORKERS)
    ds = ds.batch(self.batch_size, drop_remainder = True)
    ds = ds.prefetch(1)
    return ds

  def get_train_ds(self):
    """Endlessly repeating, shuffled training pipeline."""
    return self._build_ds(self.train_filenames, self.train_embeddings, self.train_bounding_boxes, None)

  def get_test_ds(self):
    """Single-pass, shuffled test pipeline."""
    return self._build_ds(self.test_filenames, self.test_embeddings, self.test_bounding_boxes, 1)

In [5]:
# 64x64 images with the default batch size of 64; `train_ds` is what Stage1Model.train consumes below.
dataset = Dataset(image_size = (64, 64))
train_ds = dataset.get_train_ds()

In [11]:
# Single-pass test pipeline; the bare name displays its element shapes and dtypes.
test_ds = dataset.get_test_ds()
test_ds

<PrefetchDataset shapes: ((64, 64, 64, 3), (64, 1024)), types: (tf.float32, tf.float32)>

In [6]:
# Iterator over the training pipeline, used by the shape checks in the next cells.
batch_iter = iter(train_ds)

In [9]:
# Shape of the embedding component of the next element.
next(batch_iter)[1].shape

TensorShape([1024])

In [63]:
# Leading dimension of the image component of the next element.
next(batch_iter)[0].shape[0]

64

In [33]:
# pandas is used only for the inspection table assembled in the following cells.
import pandas as pd
df = pd.DataFrame()

In [35]:
# One row per training image: its full path and its entry from the embeddings pickle.
df['file_names'] = dataset.train_filenames
df['embeddings'] = dataset.train_embeddings

In [40]:
# Boxes are [x, y, width, height], the order Dataset.crop unpacks them in.
df['bounding_boxes']= dataset.train_bounding_boxes

In [41]:
# Display the assembled table (pandas elides the middle rows).
df

Unnamed: 0,file_names,embeddings,bounding_boxes
0,./data/CUB_200_2011/images/002.Laysan_Albatros...,"[[-0.1984601, 0.06850037, -0.08754172, 0.20703...","[144, 40, 333, 165]"
1,./data/CUB_200_2011/images/002.Laysan_Albatros...,"[[-0.13848011, 0.1029341, 0.112114534, 0.04012...","[202, 28, 164, 340]"
2,./data/CUB_200_2011/images/002.Laysan_Albatros...,"[[0.055296864, 0.022586502, 0.40397644, 0.2609...","[72, 68, 383, 225]"
3,./data/CUB_200_2011/images/002.Laysan_Albatros...,"[[-0.008370306, 0.07140453, -0.049068503, -0.0...","[60, 128, 438, 106]"
4,./data/CUB_200_2011/images/002.Laysan_Albatros...,"[[-0.02236132, -0.025680661, 0.47425216, 0.103...","[32, 35, 259, 361]"
...,...,...,...
8850,./data/CUB_200_2011/images/200.Common_Yellowth...,"[[-0.054710753, 0.083771996, 0.20557919, 0.029...","[89, 95, 354, 250]"
8851,./data/CUB_200_2011/images/200.Common_Yellowth...,"[[-0.2009305, 0.15242675, 0.13367009, 0.108257...","[157, 62, 184, 219]"
8852,./data/CUB_200_2011/images/200.Common_Yellowth...,"[[-0.103044346, 0.028868947, 0.26856357, 0.114...","[190, 102, 198, 202]"
8853,./data/CUB_200_2011/images/200.Common_Yellowth...,"[[-0.085877456, 0.12509555, 0.13053851, 0.0840...","[3, 20, 408, 307]"


# Stage 1 StackGAN

In [16]:
class ConditioningAugmentation(tf.keras.Model):
  """Turns a text embedding into a sampled 128-d conditioning vector.

  A 256-unit dense layer followed by leaky ReLU produces `phi`; its first 128
  columns are used as the mean and the remaining columns as log-sigma of a
  Gaussian that is sampled with the reparameterisation mean + eps * std.
  """

  def __init__(self):
    super(ConditioningAugmentation, self).__init__()
    self.dense = tf.keras.layers.Dense(units = 256)

  def call(self, E):
    """Return `(C, phi)`: the sampled conditioning vector and the raw statistics.

    `phi` is returned so the caller can compute the KL regulariser from it.
    """
    X = self.dense(E)
    phi = tf.nn.leaky_relu(X)
    mean = phi[:, :128]
    std = K.exp(phi[:, 128:])
    # Draw independent noise for every sample. A (128,)-shaped draw would be
    # broadcast, giving every row of the batch the same epsilon.
    epsilon = tf.random.normal(tf.shape(mean), dtype = mean.dtype)
    C = mean + epsilon*std
    return C, phi

In [17]:
class EmbeddingCompressor(tf.keras.Model):
  """Projects an embedding to 128 units through a dense layer and a ReLU."""

  def __init__(self):
    super().__init__()
    self.dense = tf.keras.layers.Dense(units = 128)

  def call(self, E):
    """Return relu(dense(E))."""
    projected = self.dense(E)
    compressed = tf.nn.relu(projected)
    return compressed

In [18]:
class Stage1Generator(tf.keras.Model):
  """Stage-I generator: (text embedding, noise) -> image in [-1, 1].

  The conditioning vector and the noise are concatenated, projected to a
  4x4x1024 map and upsampled by four stride-2 transposed convolutions
  (8, 16, 32, 64 pixels); the last one has 3 filters and is followed by tanh.
  """

  def __init__(self):
    super(Stage1Generator, self).__init__()

    def deconv(filters):
      # A new initializer per layer so no two layers share one initializer object.
      return tf.keras.layers.Conv2DTranspose(filters = filters, kernel_size = 4, strides = (2, 2), padding = "same", kernel_initializer = tf.random_normal_initializer(stddev = 0.02))

    def batchnorm():
      return tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)

    self.canet = ConditioningAugmentation()
    self.concat = tf.keras.layers.Concatenate(axis = 1)
    self.dense = tf.keras.layers.Dense(units = 128*8*4*4, kernel_initializer = tf.random_normal_initializer(stddev = 0.02))
    self.reshape = tf.keras.layers.Reshape(target_shape = (4, 4, 128*8), input_shape = (128*8*4*4, ))
    self.batchnorm1 = batchnorm()

    self.deconv1 = deconv(512)
    self.batchnorm2 = batchnorm()
    self.deconv2 = deconv(256)
    self.batchnorm3 = batchnorm()
    self.deconv3 = deconv(128)
    self.batchnorm4 = batchnorm()
    self.deconv4 = deconv(3)

  def call(self, inputs, training = None):
    """Generate images.

    Args:
      inputs: pair `(E, Z)` of text embeddings and noise, both batched.
      training: forwarded to every BatchNormalization layer; None keeps
        Keras' default resolution of the training mode.

    Returns:
      `(images, phi)` where `phi` is the conditioning statistics tensor
      produced by the ConditioningAugmentation sub-model.
    """
    E, Z = inputs
    C, phi = self.canet(E)

    gen_input = self.concat([C, Z])
    X = self.dense(gen_input)
    X = self.reshape(X)
    X = self.batchnorm1(X, training = training)
    X = tf.nn.relu(X)

    # Each transposed convolution gets its own norm layer. Reusing batchnorm1
    # here would feed a 512-channel map to a layer built for 1024 channels.
    X = self.deconv1(X)
    X = self.batchnorm2(X, training = training)
    X = tf.nn.relu(X)

    X = self.deconv2(X)
    X = self.batchnorm3(X, training = training)
    X = tf.nn.relu(X)

    X = self.deconv3(X)
    X = self.batchnorm4(X, training = training)
    X = tf.nn.relu(X)

    X = self.deconv4(X)
    return tf.nn.tanh(X), phi

In [19]:
class Stage1Discriminator(tf.keras.Model):
  """Stage-I discriminator: (image, text embedding) -> one logit per sample.

  Four stride-2 convolutions downsample the image; the compressed embedding
  is tiled to 4x4 and concatenated on the channel axis, so the image branch
  must also end at 4x4 (i.e. a 64x64 input).
  """

  def __init__(self):
    super(Stage1Discriminator, self).__init__()

    def conv(filters, kernel_size, strides, padding = "same"):
      return tf.keras.layers.Conv2D(filters = filters, kernel_size = kernel_size, strides = strides, padding = padding, kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))

    def batchnorm():
      return tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)

    self.conv1 = conv(64, 4, 2)
    self.conv2 = conv(128, 4, 2)
    self.batchnorm1 = batchnorm()
    self.conv3 = conv(256, 4, 2)
    self.batchnorm2 = batchnorm()
    self.conv4 = conv(512, 4, 2)
    self.batchnorm3 = batchnorm()
    self.embed = EmbeddingCompressor()
    self.reshape = tf.keras.layers.Reshape(target_shape = (1, 1, 128))
    self.concat = tf.keras.layers.Concatenate()
    self.conv5 = conv(64*8, 1, 1)
    self.batchnorm4 = batchnorm()
    self.conv6 = conv(1, 4, 1, padding = "valid")

  def call(self, inputs, training = None):
    """Score image/embedding pairs.

    Args:
      inputs: pair `(I, E)` of images and text embeddings, both batched.
      training: forwarded to every BatchNormalization layer.

    Returns:
      Tensor of shape (batch,) holding unnormalised logits.
    """
    I, E = inputs
    X = self.conv1(I)
    X = tf.nn.leaky_relu(X)

    X = self.conv2(X)
    X = self.batchnorm1(X, training = training)
    X = tf.nn.leaky_relu(X)

    X = self.conv3(X)
    X = self.batchnorm2(X, training = training)
    X = tf.nn.leaky_relu(X)

    X = self.conv4(X)
    X = self.batchnorm3(X, training = training)
    X = tf.nn.leaky_relu(X)

    T = self.embed(E)
    T = self.reshape(T)
    T = tf.tile(T, (1, 4, 4, 1))
    merged_input = self.concat([X, T])

    Y = self.conv5(merged_input)
    Y = self.batchnorm4(Y, training = training)
    Y = tf.nn.leaky_relu(Y)

    Y = self.conv6(Y)
    # Squeeze only the 1x1x1 tail. A bare tf.squeeze would also drop the
    # batch axis when the batch holds a single sample.
    return tf.squeeze(Y, axis = [1, 2, 3])

In [20]:
class Stage1Model(tf.keras.Model):
  """Stage-I StackGAN: generator/discriminator pair plus a custom training loop."""

  def __init__(self):
    super(Stage1Model, self).__init__()
    self.stage1_generator = Stage1Generator()
    self.stage1_discriminator = Stage1Discriminator()

  def call(self, input):
    """Generate images for `input = (embeddings, noise)` and return their logits.

    NOTE(review): the previous body returned the discriminator layer object
    itself and never used `input`; running the generator and scoring its
    output is the assumed intent - confirm.
    """
    embeddings, z_noise = input
    fake_images, _ = self.stage1_generator([embeddings, z_noise])
    return self.stage1_discriminator([fake_images, embeddings])

  @staticmethod
  def _bce(labels, logits):
    """Mean sigmoid cross-entropy between soft `labels` and `logits`."""
    return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = labels, logits = logits))

  def _train_step(self, image_batch, embedding_batch, z_dim, generator_optimizer, discriminator_optimizer):
    """Run one simultaneous generator/discriminator update; return both losses."""
    # Size noise and labels from the data so a dataset built with a different
    # batch size does not cause a shape mismatch.
    current_batch = tf.shape(image_batch)[0]
    z_noise = tf.random.normal((current_batch, z_dim))

    # Rolling by one pairs every embedding with a neighbouring sample's image.
    mismatched_images = tf.roll(image_batch, shift = 1, axis = 0)

    # Soft labels: "real" in [0.9, 1.0), "fake"/"mismatched" in [0.0, 0.1).
    real_labels = tf.random.uniform(shape = (current_batch, ), minval = 0.9, maxval = 1.0)
    fake_labels = tf.random.uniform(shape = (current_batch, ), minval = 0.0, maxval = 0.1)
    mismatched_labels = tf.random.uniform(shape = (current_batch, ), minval = 0.0, maxval = 0.1)

    with tf.GradientTape() as generator_tape, tf.GradientTape() as discriminator_tape:
      # training=True makes BatchNormalization use batch statistics and update
      # its moving averages; without it the layers run in inference mode.
      fake_images, phi = self.stage1_generator([embedding_batch, z_noise], training = True)

      real_logits = self.stage1_discriminator([image_batch, embedding_batch], training = True)
      fake_logits = self.stage1_discriminator([fake_images, embedding_batch], training = True)
      mismatched_logits = self.stage1_discriminator([mismatched_images, embedding_batch], training = True)

      l_sup = self._bce(real_labels, fake_logits)
      # KL_loss never reads its first argument, so no dummy target is built.
      l_klreg = KL_loss(None, phi)
      generator_loss = l_sup + 2.0*l_klreg

      l_real = self._bce(real_labels, real_logits)
      l_fake = self._bce(fake_labels, fake_logits)
      l_mismatched = self._bce(mismatched_labels, mismatched_logits)
      discriminator_loss = 0.5*tf.add(l_real, 0.5*tf.add(l_fake, l_mismatched))

    generator_gradients = generator_tape.gradient(generator_loss, self.stage1_generator.trainable_variables)
    discriminator_gradients = discriminator_tape.gradient(discriminator_loss, self.stage1_discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(generator_gradients, self.stage1_generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, self.stage1_discriminator.trainable_variables))
    return generator_loss, discriminator_loss

  def _save_samples_and_weights(self, train_ds, epoch, z_dim):
    """Write 10 generated PNGs and both weight checkpoints for `epoch` (0-based)."""
    save_path = "./lr_results/epoch_" + str(epoch + 1)
    os.makedirs(save_path, exist_ok = True)
    os.makedirs("./weights", exist_ok = True)

    _, embeddings = next(iter(train_ds))
    temp_batch_size = 10
    temp_z_noise = tf.random.normal((temp_batch_size, z_dim))
    temp_embedding_batch = embeddings.numpy()[0:temp_batch_size]
    fake_images, _ = self.stage1_generator([temp_embedding_batch, temp_z_noise])
    pixels = (127.5*fake_images + 127.5).numpy().astype('uint8')
    for i, image in enumerate(pixels):
      # The pipeline decodes JPEGs as RGB while cv2.imwrite expects BGR.
      cv2.imwrite(save_path + "/gen_%d.png"%(i), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

    self.stage1_generator.save_weights("./weights/stage1_generator_" + str(epoch + 1) + ".ckpt")
    self.stage1_discriminator.save_weights("./weights/stage1_discriminator_" + str(epoch + 1) + ".ckpt")

  def train(self, train_ds, batch_size = 64, num_epochs = 600, z_dim = 100, c_dim = 128, stage1_generator_lr = 0.0004, stage1_discriminator_lr = 0.0004):
    """Train both networks for `num_epochs` epochs of 125 steps each.

    Args:
      train_ds: dataset yielding `(image_batch, embedding_batch)`; it must
        supply at least 125 batches per iteration pass.
      batch_size: kept for backward compatibility; the actual batch size is
        read from each batch.
      num_epochs: number of epochs to run.
      z_dim: width of the noise vector fed to the generator.
      c_dim: unused.
      stage1_generator_lr / stage1_discriminator_lr: Adam learning rates.

    Side effects: every 10th epoch (and the last) writes sample images under
    ./lr_results and checkpoints under ./weights.
    """
    generator_optimizer = tf.keras.optimizers.Adam(learning_rate = stage1_generator_lr, beta_1 = 0.5, beta_2 = 0.999)
    discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate = stage1_discriminator_lr, beta_1 = 0.5, beta_2 = 0.999)
    steps_per_epoch = 125

    for epoch in range(num_epochs):
      print("Epoch %d/%d:\n ["%(epoch + 1, num_epochs), end = "")
      start_time = time.time()
      if epoch % 100 == 0:
        # NOTE(review): this also fires at epoch 0, so training starts at half
        # the configured rate - kept as-is; confirm that this is intended.
        K.set_value(generator_optimizer.learning_rate, generator_optimizer.learning_rate / 2)
        K.set_value(discriminator_optimizer.learning_rate, discriminator_optimizer.learning_rate / 2)

      generator_loss_log = []
      discriminator_loss_log = []
      batch_iter = iter(train_ds)
      for i in range(steps_per_epoch):
        if i % 5 == 0:
          print("=", end = "")
        image_batch, embedding_batch = next(batch_iter)
        generator_loss, discriminator_loss = self._train_step(image_batch, embedding_batch, z_dim, generator_optimizer, discriminator_optimizer)
        generator_loss_log.append(generator_loss)
        discriminator_loss_log.append(discriminator_loss)

      epoch_time = time.time() - start_time
      template = "] - generator_loss: {:.4f} - discriminator_loss: {:.4f} - epoch_time: {:.2f} s"
      print(template.format(float(tf.reduce_mean(generator_loss_log)), float(tf.reduce_mean(discriminator_loss_log)), epoch_time))

      if (epoch + 1) % 10 == 0 or epoch == num_epochs - 1:
        self._save_samples_and_weights(train_ds, epoch, z_dim)

  def generate_image(self, embedding, z_dim = 100, weights_path = "stage1_generator_600.ckpt"):
    """Load generator weights and run the generator on a batch of embeddings.

    This used to be nested inside `train`, so it was never reachable as a
    method and depended on `train`'s local `batch_size` and `z_dim`.

    Args:
      embedding: batched text embeddings; one noise vector is drawn per row.
      z_dim: width of the noise vector.
      weights_path: checkpoint to restore. NOTE(review): `train` writes its
        checkpoints under ./weights/, so the default probably needs that prefix.

    Returns:
      The generator output, i.e. the pair `(images, phi)`.
    """
    self.stage1_generator.compile(loss = "mse", optimizer = "adam")
    self.stage1_generator.load_weights(weights_path).expect_partial()
    z_noise = tf.random.normal((tf.shape(embedding)[0], z_dim))
    generated_image = self.stage1_generator([embedding, z_noise])
    return generated_image

In [25]:
# Fresh Stage-I generator/discriminator pair (no weights are loaded here).
model = Stage1Model()

In [22]:
# Run the Stage-I training loop for a single epoch.
model.train(train_ds,num_epochs=1)

Epoch 1/1:


In [46]:
# Raises TypeError: get_train_ds() applies an unbounded repeat(), so the dataset has no finite length.
len(train_ds)

TypeError: dataset length is infinite.

# Stage 2 StackGAN

In [None]:
class ResidualBlock(tf.keras.layers.Layer):
  """Two 3x3, 512-filter conv + batch-norm + ReLU stages with an identity skip.

  The input must already have 512 channels so it can be added to the branch
  output.
  """

  def __init__(self):
    super(ResidualBlock, self).__init__()
    self.conv1 = tf.keras.layers.Conv2D(filters = 128*4, kernel_size = 3, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm1 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv2 = tf.keras.layers.Conv2D(filters = 128*4, kernel_size = 3, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm2 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)

  # `call` must live at class level. It used to be nested inside __init__,
  # which left the layer without its own `call`, so its convolutions never ran.
  def call(self, I, training = None):
    """Return relu(branch(I) + I); `training` is forwarded to the norm layers."""
    X = self.conv1(I)
    X = self.batchnorm1(X, training = training)
    X = tf.nn.relu(X)

    X = self.conv2(X)
    X = self.batchnorm2(X, training = training)
    X = tf.nn.relu(X)
    X = tf.add(X, I)
    X = tf.nn.relu(X)
    return X

In [None]:
class Stage2Generator(tf.keras.Model):
  """Stage-II generator: (text embedding, low-res image) -> image in [-1, 1].

  The low-res image is downsampled by two stride-2 convolutions, joined with
  the tiled conditioning vector, passed through four residual blocks and
  upsampled four times (x16 overall). The conditioning vector is tiled to
  16x16, so the input image must be 64x64.
  """

  def __init__(self):
    super(Stage2Generator, self).__init__()

    def conv(filters, kernel_size, strides = 1):
      return tf.keras.layers.Conv2D(filters, kernel_size = kernel_size, strides = strides, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))

    def batchnorm():
      return tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)

    def upsample():
      return tf.keras.layers.UpSampling2D(size = (2, 2))

    self.canet = ConditioningAugmentation()
    self.conv1 = conv(128, 3)
    self.conv2 = conv(256, 4, strides = 2)
    self.batchnorm1 = batchnorm()
    self.conv3 = conv(512, 4, strides = 2)
    self.batchnorm2 = batchnorm()
    self.conv4 = conv(512, 3)
    self.batchnorm3 = batchnorm()
    self.resblock1 = ResidualBlock()
    self.resblock2 = ResidualBlock()
    self.resblock3 = ResidualBlock()
    self.resblock4 = ResidualBlock()
    self.upsamp1 = upsample()
    self.conv5 = conv(256, 3)
    self.batchnorm4 = batchnorm()
    self.upsamp2 = upsample()
    self.conv6 = conv(128, 3)
    self.batchnorm5 = batchnorm()
    self.upsamp3 = upsample()
    self.conv7 = conv(64, 3)
    self.batchnorm6 = batchnorm()
    self.upsamp4 = upsample()
    self.conv8 = conv(32, 3)
    self.batchnorm7 = batchnorm()
    self.conv9 = conv(3, 3)

  def call(self, inputs, training = None):
    """Refine a low-resolution image.

    Args:
      inputs: pair `(E, I)` of text embeddings and low-res images, both batched.
      training: forwarded to the norm layers and residual blocks.

    Returns:
      `(images, phi)` where `phi` is the conditioning statistics tensor.
    """
    E, I = inputs
    C, phi = self.canet(E)

    X = self.conv1(I)
    X = tf.nn.relu(X)

    X = self.conv2(X)
    X = self.batchnorm1(X, training = training)
    X = tf.nn.relu(X)

    X = self.conv3(X)
    X = self.batchnorm2(X, training = training)
    X = tf.nn.relu(X)

    C = K.expand_dims(C, axis = 1)
    C = K.expand_dims(C, axis = 1)
    C = K.tile(C, [1, 16, 16, 1])
    J = K.concatenate([C, X], axis = 3)

    # Feed the joined tensor forward. Passing X here left J unused, so the
    # text conditioning was dropped.
    # NOTE(review): this gives conv4 640 input channels; checkpoints written
    # with the 512-channel kernel will no longer match.
    X = self.conv4(J)
    X = self.batchnorm3(X, training = training)
    X = tf.nn.relu(X)

    X = self.resblock1(X, training = training)
    X = self.resblock2(X, training = training)
    X = self.resblock3(X, training = training)
    X = self.resblock4(X, training = training)

    X = self.upsamp1(X)
    X = self.conv5(X)
    X = self.batchnorm4(X, training = training)
    X = tf.nn.relu(X)

    X = self.upsamp2(X)
    X = self.conv6(X)
    X = self.batchnorm5(X, training = training)
    X = tf.nn.relu(X)

    X = self.upsamp3(X)
    X = self.conv7(X)
    X = self.batchnorm6(X, training = training)
    X = tf.nn.relu(X)

    X = self.upsamp4(X)
    X = self.conv8(X)
    X = self.batchnorm7(X, training = training)
    X = tf.nn.relu(X)

    X = self.conv9(X)
    return tf.nn.tanh(X), phi

In [None]:
class Stage2Discriminator(tf.keras.Model):
  """Stage-II discriminator: (image, text embedding) -> one logit per sample.

  Six stride-2 convolutions downsample the image by 64; the compressed
  embedding is tiled to 4x4 before concatenation, so the image branch must
  end at 4x4 (i.e. a 256x256 input).
  """

  def __init__(self):
    super(Stage2Discriminator, self).__init__()

    def conv(filters, kernel_size, strides, padding = "same"):
      return tf.keras.layers.Conv2D(filters = filters, kernel_size = kernel_size, strides = strides, padding = padding, kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))

    def batchnorm():
      return tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)

    self.embed = EmbeddingCompressor()
    self.reshape = tf.keras.layers.Reshape(target_shape = (1, 1, 128))
    self.conv1 = conv(64, 4, 2)
    self.conv2 = conv(128, 4, 2)
    self.batchnorm1 = batchnorm()
    self.conv3 = conv(256, 4, 2)
    self.batchnorm2 = batchnorm()
    self.conv4 = conv(512, 4, 2)
    self.batchnorm3 = batchnorm()
    self.conv5 = conv(1024, 4, 2)
    self.batchnorm4 = batchnorm()
    self.conv6 = conv(2048, 4, 2)
    self.batchnorm5 = batchnorm()
    self.conv7 = conv(1024, 1, 1)
    self.batchnorm6 = batchnorm()
    self.conv8 = conv(512, 1, 1)
    self.batchnorm7 = batchnorm()
    self.conv9 = conv(128, 1, 1)
    self.batchnorm8 = batchnorm()
    self.conv10 = conv(128, 3, 1)
    self.batchnorm9 = batchnorm()
    self.conv11 = conv(512, 3, 1)
    self.batchnorm10 = batchnorm()
    self.conv12 = conv(64*8, 1, 1)
    self.batchnorm11 = batchnorm()
    self.conv13 = conv(1, 4, 1, padding = "valid")

  def call(self, inputs, training = None):
    """Score image/embedding pairs.

    Args:
      inputs: pair `(I, E)` of images and text embeddings, both batched.
      training: forwarded to every BatchNormalization layer.

    Returns:
      Tensor of shape (batch,) holding unnormalised logits.
    """
    I, E = inputs
    T = self.embed(E)
    T = self.reshape(T)
    T = tf.tile(T, (1, 4, 4, 1))

    X = self.conv1(I)
    X = tf.nn.leaky_relu(X)

    # Five more stride-2 stages followed by a 1x1 stage, each conv -> norm -> leaky ReLU.
    stages = (
      (self.conv2, self.batchnorm1),
      (self.conv3, self.batchnorm2),
      (self.conv4, self.batchnorm3),
      (self.conv5, self.batchnorm4),
      (self.conv6, self.batchnorm5),
      (self.conv7, self.batchnorm6),
    )
    for conv_layer, norm_layer in stages:
      X = conv_layer(X)
      X = norm_layer(X, training = training)
      X = tf.nn.leaky_relu(X)

    # Trunk output of the residual unit: no activation before the addition.
    X = self.conv8(X)
    X = self.batchnorm7(X, training = training)

    Y = self.conv9(X)
    Y = self.batchnorm8(Y, training = training)
    Y = tf.nn.leaky_relu(Y)

    Y = self.conv10(Y)
    Y = self.batchnorm9(Y, training = training)
    Y = tf.nn.leaky_relu(Y)

    Y = self.conv11(Y)
    Y = self.batchnorm10(Y, training = training)

    # Plain tensor ops instead of building new Add/Concatenate layers per call.
    A = tf.add(X, Y)
    A = tf.nn.leaky_relu(A)

    merged_input = tf.concat([A, T], axis = -1)

    Z = self.conv12(merged_input)
    Z = self.batchnorm11(Z, training = training)
    Z = tf.nn.leaky_relu(Z)

    Z = self.conv13(Z)
    # Squeeze only the 1x1x1 tail so a single-sample batch keeps its batch axis.
    return tf.squeeze(Z, axis = [1, 2, 3])

In [None]:
class Stage2Model(tf.keras.Model):
  """Stage-II StackGAN: frozen Stage-I generator plus a trainable Stage-II pair."""

  def __init__(self, stage1_generator_weights = "./lr_model_checkpoints/stage1_generator_600.ckpt", stage2_generator_weights = "./hr_model_checkpoints/stage2_generator_170.ckpt", stage2_discriminator_weights = "./hr_model_checkpoints/stage2_discriminator_170.ckpt"):
    """Build the three networks and restore their checkpoints.

    Each `*_weights` argument is a checkpoint path; pass None to leave that
    network at its fresh initialisation (e.g. to start Stage-II from scratch).
    """
    super(Stage2Model, self).__init__()
    self.stage1_generator = Stage1Generator()
    self._restore(self.stage1_generator, stage1_generator_weights)

    self.stage2_generator = Stage2Generator()
    self.stage2_discriminator = Stage2Discriminator()
    self._restore(self.stage2_generator, stage2_generator_weights)
    self._restore(self.stage2_discriminator, stage2_discriminator_weights)

  @staticmethod
  def _restore(network, weights_path):
    """Compile `network` and, when a path is given, load its weights."""
    network.compile(loss = "mse", optimizer = "adam")
    if weights_path is not None:
      network.load_weights(weights_path).expect_partial()

  @staticmethod
  def _bce(labels, logits):
    """Mean sigmoid cross-entropy between soft `labels` and `logits`."""
    return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels = labels, logits = logits))

  def _train_step(self, hr_image_batch, embedding_batch, z_dim, generator_optimizer, discriminator_optimizer):
    """Run one simultaneous Stage-II generator/discriminator update; return both losses."""
    # Size noise and labels from the data rather than from a separate argument.
    current_batch = tf.shape(hr_image_batch)[0]
    z_noise = tf.random.normal((current_batch, z_dim))

    # Rolling by one pairs every embedding with a neighbouring sample's image.
    mismatched_images = tf.roll(hr_image_batch, shift = 1, axis = 0)

    real_labels = tf.random.uniform(shape = (current_batch, ), minval = 0.9, maxval = 1.0)
    fake_labels = tf.random.uniform(shape = (current_batch, ), minval = 0.0, maxval = 0.1)
    mismatched_labels = tf.random.uniform(shape = (current_batch, ), minval = 0.0, maxval = 0.1)

    with tf.GradientTape() as generator_tape, tf.GradientTape() as discriminator_tape:
      # Stage-I is not updated here, so it runs in inference mode; the Stage-II
      # networks need training=True for BatchNormalization to use batch statistics.
      lr_fake_images, _ = self.stage1_generator([embedding_batch, z_noise], training = False)
      hr_fake_images, phi = self.stage2_generator([embedding_batch, lr_fake_images], training = True)
      real_logits = self.stage2_discriminator([hr_image_batch, embedding_batch], training = True)
      fake_logits = self.stage2_discriminator([hr_fake_images, embedding_batch], training = True)
      mismatched_logits = self.stage2_discriminator([mismatched_images, embedding_batch], training = True)

      l_sup = self._bce(real_labels, fake_logits)
      # KL_loss never reads its first argument, so no dummy target is built.
      l_klreg = KL_loss(None, phi)
      generator_loss = l_sup + 2.0*l_klreg

      l_real = self._bce(real_labels, real_logits)
      l_fake = self._bce(fake_labels, fake_logits)
      l_mismatched = self._bce(mismatched_labels, mismatched_logits)
      discriminator_loss = 0.5*tf.add(l_real, 0.5*tf.add(l_fake, l_mismatched))

    generator_gradients = generator_tape.gradient(generator_loss, self.stage2_generator.trainable_variables)
    discriminator_gradients = discriminator_tape.gradient(discriminator_loss, self.stage2_discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(generator_gradients, self.stage2_generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, self.stage2_discriminator.trainable_variables))
    return generator_loss, discriminator_loss

  def _save_samples_and_weights(self, train_ds, epoch, z_dim):
    """Write 10 generated PNGs and both Stage-II checkpoints for `epoch` (0-based)."""
    save_path = "./hr_results/epoch_" + str(epoch + 1)
    os.makedirs(save_path, exist_ok = True)
    os.makedirs("./hr_model_checkpoints", exist_ok = True)

    _, embeddings = next(iter(train_ds))
    temp_batch_size = 10
    temp_z_noise = tf.random.normal((temp_batch_size, z_dim))
    temp_embedding_batch = embeddings.numpy()[0:temp_batch_size]
    # Preview the network being trained: Stage-I output refined by Stage-II.
    # The previous code saved the raw Stage-I images into hr_results.
    lr_images, _ = self.stage1_generator([temp_embedding_batch, temp_z_noise])
    fake_images, _ = self.stage2_generator([temp_embedding_batch, lr_images])
    pixels = (127.5*fake_images + 127.5).numpy().astype('uint8')
    for i, image in enumerate(pixels):
      # The pipeline decodes JPEGs as RGB while cv2.imwrite expects BGR.
      cv2.imwrite(save_path + "/gen_%d.png"%(i), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

    self.stage2_generator.save_weights("./hr_model_checkpoints/stage2_generator_" + str(epoch + 1) + ".ckpt")
    self.stage2_discriminator.save_weights("./hr_model_checkpoints/stage2_discriminator_" + str(epoch + 1) + ".ckpt")

  def train(self, train_ds, batch_size = 64, num_epochs = 1200, z_dim = 100, stage1_generator_lr = 0.0001, stage1_discriminator_lr = 0.0001):
    """Train the Stage-II networks for `num_epochs` epochs of 125 steps each.

    Args:
      train_ds: dataset yielding `(hr_image_batch, embedding_batch)`; it must
        supply at least 125 batches per iteration pass.
      batch_size: kept for backward compatibility; the actual batch size is
        read from each batch.
      num_epochs: number of epochs to run.
      z_dim: width of the noise vector fed to the Stage-I generator.
      stage1_generator_lr / stage1_discriminator_lr: Adam learning rates for
        the Stage-II generator and discriminator (the parameter names are
        historical).

    Side effects: every 10th epoch (and the last) writes sample images under
    ./hr_results and checkpoints under ./hr_model_checkpoints.
    """
    generator_optimizer = tf.keras.optimizers.Adam(learning_rate = stage1_generator_lr, beta_1 = 0.5, beta_2 = 0.999)
    discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate = stage1_discriminator_lr, beta_1 = 0.5, beta_2 = 0.999)
    steps_per_epoch = 125

    for epoch in range(num_epochs):
      print("Epoch %d/%d:\n ["%(epoch + 1, num_epochs), end = "")
      start_time = time.time()
      if epoch % 100 == 0:
        # NOTE(review): this also fires at epoch 0, so training starts at half
        # the configured rate - kept as-is; confirm that this is intended.
        K.set_value(generator_optimizer.learning_rate, generator_optimizer.learning_rate / 2)
        K.set_value(discriminator_optimizer.learning_rate, discriminator_optimizer.learning_rate / 2)

      generator_loss_log = []
      discriminator_loss_log = []
      batch_iter = iter(train_ds)
      for i in range(steps_per_epoch):
        if i % 5 == 0:
          print("=", end = "")
        hr_image_batch, embedding_batch = next(batch_iter)
        generator_loss, discriminator_loss = self._train_step(hr_image_batch, embedding_batch, z_dim, generator_optimizer, discriminator_optimizer)
        generator_loss_log.append(generator_loss)
        discriminator_loss_log.append(discriminator_loss)

      epoch_time = time.time() - start_time
      template = "] - generator_loss: {:.4f} - discriminator_loss: {:.4f} - epoch_time: {:.2f} s"
      print(template.format(float(tf.reduce_mean(generator_loss_log)), float(tf.reduce_mean(discriminator_loss_log)), epoch_time))

      if (epoch + 1) % 10 == 0 or epoch == num_epochs - 1:
        self._save_samples_and_weights(train_ds, epoch, z_dim)

In [None]:
# 256x256 pipelines for Stage II. This rebinds `train_ds` and `test_ds`, which
# pointed at the 64x64 pipelines above, so do not re-run the Stage-I cells after this one.
hr_dataset = Dataset(image_size = (256, 256), batch_size = 64)
train_ds = hr_dataset.get_train_ds()
test_ds = hr_dataset.get_test_ds()

In [None]:
# Constructing Stage2Model loads the Stage-I and Stage-II checkpoints from disk;
# training then runs with the default settings on the 256x256 pipeline.
stage2_model = Stage2Model()
stage2_model.train(train_ds)