In [1]:
import tensorflow as tf
from tensorflow import keras 
import numpy as np
from  utils.dataset import Dataset
# dataset = Dataset(image_size=(64,64),data_base_path="./data",batch_size=32)
# train_ds = dataset.get_train_ds()

In [2]:
# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
# Iterating (rather than indexing [0]) configures every visible GPU and keeps this
# cell working on CPU-only machines, where the device list is empty.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

In [25]:
from re import template
from numpy.lib.type_check import imag
import tensorflow as tf
from tensorflow.keras import backend as  K 
from tensorflow import keras
from tensorflow.keras import layers as L
import time
import os
from models.stackgan import Stage1Model 
# Stage-1 StackGAN model (imported from models.stackgan); its `.generator` is later
# reused by Stage2Model to produce the low-resolution inputs for stage 2.
model = Stage1Model()


In [4]:
# Restore stage-1 weights from ./weights/weights_50 (presumably the checkpoint
# written after epoch 50 of stage-1 training — confirm).
model.load_weights(path="./weights/weights_50")

In [8]:
# Stage-1 training. The dataset is built here (64x64 images, batch 32, matching the
# commented-out dataset in the first cell) so this cell does not depend on whatever
# `train_ds` currently holds in the kernel — the stage-2 dataset cell rebinds
# `train_ds` to 256x256 test data.
stage1_dataset = Dataset(image_size=(64, 64), data_base_path="./data", batch_size=32)
stage1_train_ds = stage1_dataset.get_train_ds()
model.train(stage1_train_ds, batch_size=32, num_epochs=100, save_weights_epoch=5)

In [7]:
# Stage-2 data: 256x256 images, two per batch.
# NOTE(review): `train_ds` is built from the *test* split (get_test_ds), not the
# training split — confirm this is intended.
STAGE2_IMAGE_SIZE = (256, 256)
dataset = Dataset(
    image_size=STAGE2_IMAGE_SIZE,
    data_base_path="./data",
    batch_size=2,
)
train_ds = dataset.get_test_ds()

In [12]:
# Number of files in the test split, which is the split `train_ds` is built from.
num_test_files = len(dataset.test_filenames)
num_test_files

2933

In [23]:
def ResidualBlock(input, num_filters):
    """Residual block: relu(F(input) + shortcut(input)).

    F is conv3x3 -> BatchNorm -> ReLU -> conv3x3 -> BatchNorm (stride 1,
    'same' padding, so spatial size is preserved). The shortcut is the input
    itself when it already has ``num_filters`` channels, otherwise a 1x1
    convolution that projects it to ``num_filters`` channels so the two
    branches can be added.

    New layers are created on every call.

    Args:
        input: 4-D feature map; channels are read from the last axis
            (channels-last, as used elsewhere in this notebook).
        num_filters: number of output channels.

    Returns:
        Tensor with the input's spatial size and ``num_filters`` channels.
    """
    shortcut = input
    if input.shape[-1] != num_filters:
        # Projection shortcut: channel counts differ, so identity cannot be added.
        shortcut = L.Conv2D(filters=num_filters, kernel_size=1, strides=1, padding='same')(input)

    x = L.Conv2D(filters=num_filters, kernel_size=3, strides=1, padding='same')(input)
    x = L.BatchNormalization()(x)
    x = L.ReLU()(x)
    x = L.Conv2D(filters=num_filters, kernel_size=3, strides=1, padding='same')(x)
    x = L.BatchNormalization()(x)
    # The skip connection is what makes this block residual; activate after the sum.
    x = L.Add()([shortcut, x])
    return L.ReLU()(x)


class Stage2Generator(keras.Model):
    """Stage-II generator: refines a low-resolution image conditioned on a text embedding.

    ``call(inputs)`` takes ``(image, embedding)`` and returns
    ``(generated_image, phi)``, where ``phi`` is the second output of
    ``ConditioningAugmentation`` (consumed by the KL regulariser in training).

    The convolutional trunk is built once, on the first call, as a functional
    sub-model stored in ``self.body``. Its layers and weights are therefore
    created a single time, tracked in ``trainable_variables`` (so they are
    optimised and saved), and reused by every later call. Building the layers
    inside ``call`` on every forward pass would create fresh, untracked
    random weights each time.
    """

    def __init__(self, *args, **kwargs):
        super(Stage2Generator, self).__init__(*args, **kwargs)
        self.augmentation = ConditioningAugmentation()
        # Not used by call(); kept so existing references to `.reshape` keep working.
        self.reshape = tf.keras.layers.Reshape(target_shape = (1, 1, 128))
        # Convolutional trunk, built lazily by call() once input shapes are known.
        self.body = None

    def _build_body(self, image_shape, condition_shape):
        """Create the encoder -> residual -> upsampling trunk as a functional model.

        Args:
            image_shape: shape of the low-resolution image batch (batch axis first).
            condition_shape: shape of the spatially tiled conditioning tensor
                (batch axis first).

        Returns:
            A ``keras.Model`` mapping ``[image, condition]`` to the generated image.
        """
        image_in = L.Input(shape=tuple(image_shape[1:]))
        condition_in = L.Input(shape=tuple(condition_shape[1:]))

        # Image encoder: one stride-1 block, then two stride-2 blocks (default strides=2).
        x = DownSamplingBlock(image_in, num_filters=64, kernel_size=3, strides=1, batch_norm=False)
        x = DownSamplingBlock(x, num_filters=256)
        x = DownSamplingBlock(x, num_filters=512)
        # Join text conditioning and image features along the channel axis.
        x = L.Concatenate(axis=3)([condition_in, x])
        x = ResidualBlock(x, 128)
        x = ResidualBlock(x, 256)
        x = ResidualBlock(x, 128)
        x = UpSamplingBlock(x, 256)
        x = ResidualBlock(x, 256)
        x = UpSamplingBlock(x, 256)
        x = ResidualBlock(x, 128)
        x = UpSamplingBlock(x, 256)
        x = ResidualBlock(x, 128)
        x = UpSamplingBlock(x, 3)
        return keras.Model(inputs=[image_in, condition_in], outputs=x)

    def call(self, inputs):
        image, embedding = inputs
        c, phi = self.augmentation(embedding)
        # (batch, dim) -> (batch, 1, 1, dim) -> tiled over a 16x16 grid, which must
        # match the spatial size of the encoded image features.
        c = K.expand_dims(c, axis=1)
        c = K.expand_dims(c, axis=1)
        c = K.tile(c, [1, 16, 16, 1])
        if self.body is None:
            self.body = self._build_body(image.shape, c.shape)
        x = self.body([image, c])
        return x, phi

def DownSamplingBlock(inputs,
                      num_filters,
                      kernel_size=4,
                      strides=2,
                      batch_norm=True,
                      activation=True):
    """Conv2D ('same' padding), optionally followed by BatchNorm and LeakyReLU.

    New layers are instantiated on every call, so each call adds new weights.

    Args:
        inputs: input tensor.
        num_filters: number of convolution output channels.
        kernel_size: convolution kernel size (default 4).
        strides: convolution stride (default 2).
        batch_norm: append a BatchNormalization layer when True.
        activation: append a LeakyReLU layer (Keras default slope) when True.

    Returns:
        Output tensor of the last layer applied.
    """
    conv = L.Conv2D(filters=num_filters, kernel_size=kernel_size, strides=strides, padding='same')
    trailing_layers = []
    if batch_norm:
        trailing_layers.append(L.BatchNormalization())
    if activation:
        trailing_layers.append(L.LeakyReLU())

    out = conv(inputs)
    for layer in trailing_layers:
        out = layer(out)
    return out


class Stage2Discriminator(tf.keras.Model):
    """Stage-II discriminator scoring (image, text-embedding) pairs.

    ``call(inputs)`` takes ``(images, embeddings)`` and returns one logit per
    sample, shape ``(batch,)``.

    The image trunk and joint head are built once, on the first call, as a
    functional sub-model stored in ``self.body``. Its layers and weights are
    therefore created a single time, tracked in ``trainable_variables``, and
    shared by every call. Real, generated and mismatched batches are all
    scored by the same network.
    """

    def __init__(self):
        super(Stage2Discriminator, self).__init__()
        self.embed = EmbeddingCompresssor()
        # Assumes EmbeddingCompresssor outputs 128 features per sample — TODO confirm.
        self.reshape = tf.keras.layers.Reshape(target_shape = (1, 1, 128))
        self.conv_out = tf.keras.layers.Conv2D(filters = 1, kernel_size = 4, strides = 1, padding = "valid", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
        # Built lazily by call() once the image shape is known.
        self.body = None

    def _build_body(self, image_shape, text_shape):
        """Return a functional model mapping ``[images, tiled_text]`` to joint features.

        Args:
            image_shape: shape of the image batch (batch axis first).
            text_shape: shape of the spatially tiled text features (batch axis first).
        """
        image_in = L.Input(shape=tuple(image_shape[1:]))
        text_in = L.Input(shape=tuple(text_shape[1:]))

        # Six stride-2 blocks shrink the spatial size by 64x (256x256 -> 4x4),
        # which must match the 4x4 tiling of the text features in call().
        x = DownSamplingBlock(image_in, num_filters=64, batch_norm=False)
        x = DownSamplingBlock(x, num_filters=128)
        x = DownSamplingBlock(x, num_filters=256)
        x = DownSamplingBlock(x, num_filters=512)
        x = DownSamplingBlock(x, num_filters=1024)
        x = DownSamplingBlock(x, num_filters=512)
        x = DownSamplingBlock(x, num_filters=128, kernel_size=1, strides=1)

        y = DownSamplingBlock(x, num_filters=128, kernel_size=1, strides=1)
        y = DownSamplingBlock(y, num_filters=256, kernel_size=3, strides=1)
        y = DownSamplingBlock(y, num_filters=128, kernel_size=3, strides=1)

        merged = L.Add()([x, y])
        # alpha=0.2 is the same slope as tf.nn.leaky_relu's default.
        merged = L.LeakyReLU(alpha=0.2)(merged)
        merged = L.Concatenate()([merged, text_in])
        z = DownSamplingBlock(merged, 128, kernel_size=1, strides=1)
        return tf.keras.Model(inputs=[image_in, text_in], outputs=z)

    def call(self, inputs):
        images, embeddings = inputs
        text = self.embed(embeddings)
        text = self.reshape(text)
        # Broadcast the text features over the 4x4 spatial grid of the image trunk.
        text = tf.tile(text, (1, 4, 4, 1))
        if self.body is None:
            self.body = self._build_body(images.shape, text.shape)
        z = self.body([images, text])
        z = self.conv_out(z)
        # (batch, 1, 1, 1) -> (batch,). Explicit axes keep the batch axis when batch == 1.
        return tf.squeeze(z, axis=[1, 2, 3])


class Stage2Model(keras.Model):
    """Stage-II StackGAN trainer.

    Reuses the generator of an existing stage-1 model (``stage1model.generator``)
    to produce low-resolution images, and trains a new stage-II
    generator/discriminator pair on top of them. ``train`` only applies
    gradients to ``generator2`` and ``discriminator2``. ``generator1`` is run
    in inference mode and is never updated.
    """

    def __init__(self, stage1model):
        super(Stage2Model, self).__init__()
        self.generator1 = stage1model.generator
        self.generator2 = Stage2Generator()
        self.discriminator2 = Stage2Discriminator()
        self.generator2_optimizer = keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.5, beta_2=0.999)
        self.discriminator2_optimizer = keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.5, beta_2=0.999)
        # Length of the noise vector fed to the stage-1 generator.
        self.noise_dim = 100

    def train(self, train_ds, batch_size=64, num_epochs=1, steps_per_epoch=125):
        """Run the adversarial training loop.

        Every 10 epochs, and after the last epoch, sample images are written to
        ./hr_results/epoch_<n> and weights to ./weights/hr_weights_<n>.

        Args:
            train_ds: iterable yielding ``(hr_images, embeddings)`` batches; it is
                re-iterated from the start every epoch.
            batch_size: unused. The effective batch size is read from each batch.
                Kept for backward compatibility.
            num_epochs: number of epochs to run.
            steps_per_epoch: maximum number of batches per epoch. An epoch ends
                early if ``train_ds`` runs out first.
        """
        for epoch in range(num_epochs):
            print("Epoch %d/%d:\n  " % (epoch + 1, num_epochs), end="")
            start_time = time.time()
            # Halve both learning rates every 100 epochs. The `epoch > 0` guard keeps
            # the initial rate for the first 100 epochs instead of halving it
            # before the very first step.
            if epoch > 0 and epoch % 100 == 0:
                self._halve_learning_rates()

            generator_loss_log = []
            discriminator_loss_log = []
            # zip() stops on the exhausted range *before* pulling another batch, and
            # also ends cleanly (no StopIteration) if train_ds is shorter.
            for step, (hr_image_batch, embedding_batch) in zip(range(steps_per_epoch), train_ds):
                if step % 5 == 0:
                    print("=", end="")
                generator_loss, discriminator_loss = self._run_train_step(hr_image_batch, embedding_batch)
                generator_loss_log.append(generator_loss)
                discriminator_loss_log.append(discriminator_loss)
            epoch_time = time.time() - start_time

            log_template = " - generator_loss: {:.4f} - discriminator_loss: {:.4f} - epoch_time: {:.2f} s"
            print(log_template.format(
                float(tf.reduce_mean(generator_loss_log)),
                float(tf.reduce_mean(discriminator_loss_log)),
                epoch_time,
            ))

            if (epoch + 1) % 10 == 0 or epoch == num_epochs - 1:
                self._write_samples(train_ds, "./hr_results/epoch_" + str(epoch + 1))
                self._write_weights(f"./weights/hr_weights_{epoch+1}")

    def _halve_learning_rates(self):
        """Divide the learning rate of both stage-II optimizers by two."""
        for optimizer in (self.generator2_optimizer, self.discriminator2_optimizer):
            K.set_value(optimizer.learning_rate, optimizer.learning_rate / 2)

    def _run_train_step(self, hr_image_batch, embedding_batch):
        """Apply one generator update and one discriminator update for a batch.

        Returns:
            ``(generator_loss, discriminator_loss)`` as scalar tensors.
        """
        batch_size = hr_image_batch.shape[0]
        z_noise = tf.random.normal((batch_size, self.noise_dim))
        # Shift images by one position so each embedding is paired with another sample's image.
        mismatched_images = tf.roll(hr_image_batch, shift=1, axis=0)
        # Smoothed targets: real ~ U(0.9, 1.0); fake and mismatched ~ U(0.0, 0.1).
        real_labels = tf.random.uniform(shape=(batch_size,), minval=0.9, maxval=1.0)
        fake_labels = tf.random.uniform(shape=(batch_size,), minval=0.0, maxval=0.1)
        mismatched_labels = tf.random.uniform(shape=(batch_size,), minval=0.0, maxval=0.1)

        with tf.GradientTape() as generator_tape, tf.GradientTape() as discriminator_tape:
            # Stage-1 only supplies low-resolution inputs: inference mode.
            lr_fake_images, _ = self.generator1([embedding_batch, z_noise], training=False)
            # training=True so layers such as BatchNormalization use batch statistics
            # and update their moving averages while training.
            hr_fake_images, phi = self.generator2([lr_fake_images, embedding_batch], training=True)
            real_logits = self.discriminator2([hr_image_batch, embedding_batch], training=True)
            fake_logits = self.discriminator2([hr_fake_images, embedding_batch], training=True)
            mismatched_logits = self.discriminator2([mismatched_images, embedding_batch], training=True)

            # Generator: push D's score on *generated* images towards "real".
            l_sup = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=real_labels, logits=fake_logits))
            l_klreg = KL_loss(tf.random.normal((phi.shape[0], phi.shape[1])), phi)
            generator_loss = l_sup + 2.0 * l_klreg

            # Discriminator: real pairs -> real, generated and mismatched pairs -> fake.
            l_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=real_labels, logits=real_logits))
            l_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=fake_labels, logits=fake_logits))
            l_mismatched = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=mismatched_labels, logits=mismatched_logits))
            discriminator_loss = 0.5 * (l_real + l_fake + l_mismatched)

        generator_variables = self.generator2.trainable_variables
        generator_gradients = generator_tape.gradient(generator_loss, generator_variables)
        self.generator2_optimizer.apply_gradients(zip(generator_gradients, generator_variables))

        discriminator_variables = self.discriminator2.trainable_variables
        discriminator_gradients = discriminator_tape.gradient(discriminator_loss, discriminator_variables)
        self.discriminator2_optimizer.apply_gradients(zip(discriminator_gradients, discriminator_variables))
        return generator_loss, discriminator_loss

    def _write_samples(self, train_ds, save_path):
        """Generate images for the first batch of ``train_ds`` and save them as PNGs in ``save_path``."""
        sample_embeddings = None
        for _, embeddings in train_ds:
            sample_embeddings = embeddings
            break
        if sample_embeddings is None:
            # Empty dataset: nothing to condition on.
            return
        os.makedirs(save_path, exist_ok=True)

        sample_noise = tf.random.normal((sample_embeddings.shape[0], self.noise_dim))
        lr_images, _ = self.generator1([sample_embeddings, sample_noise], training=False)
        fake_images, _ = self.generator2([lr_images, sample_embeddings], training=False)
        for i, image in enumerate(fake_images):
            # Map the assumed [-1, 1] generator range to [0, 255]. Clip first so
            # out-of-range values saturate instead of wrapping around in uint8.
            pixels = tf.clip_by_value(127.5 * image + 127.5, 0.0, 255.0).numpy().astype('uint8')
            # scale=False: pixels are already in [0, 255]; the default would
            # re-stretch each image to its own min/max.
            keras.preprocessing.image.array_to_img(pixels, scale=False).save(save_path + "/gen_%d.png" % (i))

    def _write_weights(self, weights_path):
        """Save the stage-II generator and discriminator weights as .h5 files in ``weights_path``."""
        os.makedirs(weights_path, exist_ok=True)
        self.generator2.save_weights(weights_path + "/stage2_generator.h5")
        self.discriminator2.save_weights(weights_path + "/stage2_discriminator.h5")
# Stage-2 model built on top of the stage-1 model's generator.
model2 = Stage2Model(stage1model=model)

In [24]:
# Stage-2 training run. `batch_size` is not used by Stage2Model.train: the effective
# batch size is read from each batch of `train_ds`.
model2.train(train_ds=train_ds, batch_size=4)

Epoch 1/1:
  =

ResourceExhaustedError: OOM when allocating tensor with shape[2,128,256,256] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc [Op:Conv2DBackpropInput]