In [1]:
# Imports
import tensorflow as tf
from tensorflow import keras
from keras.preprocessing.image import ImageDataGenerator
import numpy as np
import cv2
import os
from PIL import Image

Using TensorFlow backend.


In [2]:
# Creating Generator
train_datagen = ImageDataGenerator(validation_split = 0.2)
train_generator = train_datagen.flow_from_directory(
        "/home/mihir_jain/data/", target_size = (256, 256),color_mode = "rgb", batch_size=50,class_mode = "input",subset = "training")

Found 24000 images belonging to 1 classes.


In [3]:
# Metric
def dice_coefficient(y_true, y_pred):
    y_true_flattened = keras.backend.flatten(y_true)
    y_pred_flattened = keras.backend.flatten(y_pred)
    x = keras.backend.sum(y_true_flattened * y_pred_flattened)
    y=keras.backend.sum(y_true_flattened + y_pred_flattened)
    return 2*x/y

In [4]:
# Model
class Generator():
    def prepare_model(self, input_size=(256,256,3)):
        acti_function = "relu"
        padding = "same"
        filters = 32
        kernel_size = (3,3)
        pool_size = (2,2)
        up_kernel = (2,2)
        up_stride = (2,2)

        # Inputs
        inputs = keras.layers.Input(input_size)

        # encoder 
        conv1, pooling1 = self.Convulation_layer(filters,kernel_size, pool_size, acti_function, padding, inputs)
        conv2, pooling2 = self.Convulation_layer(filters*2, kernel_size, pool_size, acti_function, padding, pooling1)
        conv3, pooling3 = self.Convulation_layer(filters*4, kernel_size, pool_size, acti_function,padding, pooling2) 
        conv4, pooling4 = self.Convulation_layer(filters*8, kernel_size, pool_size, acti_function,padding, pooling3) 
        # decoder 
        conv5, up6 = self.Up_Convulation_layer(filters*16, filters*8, kernel_size, up_kernel, up_stride, acti_function, padding, pooling4, conv4)
        conv6, up7 = self.Up_Convulation_layer(filters*8, filters*4, kernel_size, up_kernel, up_stride, acti_function,padding, up6, conv3)
        conv7, up8 = self.Up_Convulation_layer(filters*4, filters*2, kernel_size, up_kernel, up_stride, acti_function, padding, up7, conv2)
        conv8, up9 = self.Up_Convulation_layer(filters*2,filters, kernel_size, up_kernel, up_stride, acti_function,padding, up8, conv1)
        conv9 = keras.layers.Conv2D(filters=32, kernel_size=(3,3), activation='relu', padding='same')(up9)
        conv9 = keras.layers.Conv2D(filters=32, kernel_size=(3,3), activation='relu', padding='same')(conv9)
        outputs = keras.layers.Conv2D(3, (3, 3), activation='sigmoid', padding='same')(conv9)
        return keras.models.Model(inputs=[inputs], outputs=[outputs]) 
    
    def Convulation_layer(self, filters, kernel_size, pool_size, activation, padding, connecting_layer, pool_layer=True):
        conv = keras.layers.Conv2D(filters=filters, kernel_size=kernel_size, activation=activation, padding=padding)(connecting_layer)
        conv = keras.layers.Conv2D(filters=filters, kernel_size=kernel_size, activation=activation, padding=padding)(conv)
        pooling = keras.layers.MaxPooling2D(pool_size)(conv)
        return conv, pooling

    def Up_Convulation_layer(self, filters, up_filters, kernel_size, up_kernel, up_stride, activation, padding, connecting_layer, shared_layer):
        conv = keras.layers.Conv2D(filters=filters, kernel_size=kernel_size, activation=activation, padding=padding)(connecting_layer)
        conv = keras.layers.Conv2D(filters=filters, kernel_size=kernel_size, activation=activation, padding=padding)(conv)
        up = keras.layers.Conv2DTranspose(filters=up_filters, kernel_size=up_kernel, strides=up_stride, padding=padding)(conv)
        up = keras.layers.concatenate([up, shared_layer], axis=3)
        return conv, up

In [5]:
fp = open("./logs_gan.txt","a")

In [6]:
class Discriminator():
    def __init__(self):
        pass
    
    def prepare_model(self, input_size = (256,256,3)):
        activation = "relu"
        padding = "same"
        # input
        inputs = keras.layers.Input(input_size)

        # encoder
        conv1 = keras.layers.Conv2D(filters = 64,kernel_size = (4,4),strides = 2, activation = activation,padding = padding)(inputs)
        conv2 = keras.layers.Conv2D(filters = 128,kernel_size = (4,4),strides = 2, activation = activation,padding = padding)(conv1)
        conv3 = keras.layers.Conv2D(filters = 256,kernel_size = (4,4),strides = 2, activation = activation,padding = padding)(conv2)
        conv4 = keras.layers.Conv2D(filters = 512,kernel_size = (4,4),strides = 2, activation = activation,padding = padding)(conv3)

        # Flatten
        flat = keras.layers.Flatten()(conv4)

        # Fully connected
        output = keras.layers.Dense(units = 1, activation = "sigmoid")(flat)

        return keras.models.Model(inputs=[inputs], outputs=[output])

In [7]:
generator = Generator().prepare_model()
discriminator = Discriminator().prepare_model()
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

In [8]:
cross_entropy = tf.keras.losses.BinaryCrossentropy()

In [9]:
def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

In [10]:
def generator_loss1(fake_output):
    loss = cross_entropy(tf.ones_like(fake_output), fake_output)
    return loss

def generator_loss2(real,output):
    loss = tf.reduce_mean(tf.abs(tf.subtract(output, real)))
    return loss

def generator_loss(real, output, fake_output):
    loss1 = generator_loss1(fake_output)
    loss2 = generator_loss2(real, output)
    tf.print(float(0.01*loss1 + 0.99*loss2))
    return 0.01*loss1 + 0.99*loss2

In [11]:
@tf.function
def train_step(x, y, generator, discriminator, opt1, opt2):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(x,training = True)
        real_output = discriminator(y, training = True)
        fake_output = discriminator(generated_images, training = True)
        gen_loss = generator_loss(y,generated_images,fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)
    tf.print(dice_coefficient(y,generated_images))
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    opt1.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    opt2.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

In [12]:
# Creating masks for images
def createmask(img,mask):
    img = img.astype(np.uint8)
    masked_image = np.copy(img)
    for l in range(len(img)):
        masked_image[l] = cv2.bitwise_and(img[l], mask)
    return masked_image/255

In [14]:
import sys
import matplotlib.pyplot as plt
from tqdm import tqdm
mask = np.full((256,256,3), 255, np.uint8)

orignal = sys.stdout

# Creating a rectangular 64*64 mask
for i in range(3):
    for j in range(256):
        for k in range(256):
            if j>96 and 160>j and k>96 and 160>k:
                mask[j][k][i] = 1
epochs = 10
cnt = 0
num_images = 24000
batch_size = 50
batch_images = num_images//batch_size
for epoch in range(1,epochs):
    for x,y in train_generator:
        x_mask = createmask(x,mask).astype(np.float32)
        sys.stdout = fp
        train_step(x_mask,y/255,generator, discriminator,generator_optimizer,discriminator_optimizer)
        sys.stdout = orignal
        cnt += 1
        if (cnt == batch_images):
            break
        print(cnt,end = "\r")
    if epoch == 0:
        for x,y in train_generator:
            x_mask = createmask(x,mask).astype(np.float32)
            yhat = generator.predict(x_mask)
            y = y/255
            plt.imshow(yhat[0])
            break
    cnt = 0
    print("Epoch")
    generator.save_weights("./models_gan/d"+str(epoch)+".h5")
    discriminator.save_weights("./models_gan/g"+str(epoch)+".h5")

Epoch
Epoch
Epoch
Epoch
Epoch
Epoch
Epoch
Epoch
Epoch
