In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time
import random

from IPython import display
from skimage.draw import random_shapes, rectangle, polygon, circle

# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
# `tf.config.gpu.set_per_process_memory_growth` appears only in TF 2.0 pre-release builds;
# on stable TF 2.x fall back to enabling memory growth per physical GPU device.
if hasattr(tf.config, 'gpu'):
    tf.config.gpu.set_per_process_memory_growth(enabled=True)
else:
    for _gpu_device in tf.config.experimental.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(_gpu_device, True)

In [None]:
# --- Synthetic data ---
SHAPE_SIZE = 128                # images and masks are SHAPE_SIZE x SHAPE_SIZE pixels
CHANNELS = 3                    # 3 = RGB input images, 1 = grayscale
MAX_SHAPES = 3                  # upper bound on shapes drawn per image
IMAGE_COUNT_STAGE_ONE = 300     # training pairs generated for stage 1
IMAGE_COUNT_STAGE_TWO = 600     # training pairs generated for stage 2

# --- Segmentation classes ---
NUM_CLASSES = 1 + 2             # background + triangle + rectangle
COLOR_TRIANGLE = 33             # pixel intensity associated with each shape type
COLOR_RECTANGLE = 66
COLOR_CIRCLE = 99

# --- Training ---
BATCH_SIZE = 32
EPOCHS_STAGE_ONE = 75
EPOCHS_STAGE_TWO = 75
NUM_EXAMPLES_TO_GENERATE = 8    # size of the fixed sample batch visualised after each epoch
USE_FASTSCNN = False            # True: build the generator with the external `fastscnn` module

In [None]:
def _merge_binary_masks(mask_list, size):
    """Return the pixel-wise union of 0/1 uint8 masks as one 0/1 uint8 array of shape `size`.

    An empty list yields an all-zero mask.
    """
    merged = np.zeros(size, dtype=np.uint8)
    for single_mask in mask_list:
        merged = np.add(merged, single_mask)
    return np.clip(merged, 0, 1)


def generate_image_mask(size=(448, 448), max_shapes=3, show_result=False):
    """Draw a random-shapes image and build its per-pixel class mask.

    Shapes come from skimage's `random_shapes` (non-overlapping, minimum size about a
    quarter of the image height). The draw is repeated until it contains at least one
    triangle or rectangle. Any other shape (e.g. circles) is not annotated and therefore
    belongs to the background class.

    Args:
        size: (height, width) of the generated image and mask.
        max_shapes: maximum number of shapes `random_shapes` may draw.
        show_result: if True, print the shape labels and plot the image next to every
            annotated per-shape mask.

    Returns:
        (image, final_mask):
          - image: the array returned by `random_shapes` (multichannel when CHANNELS == 3).
          - final_mask: uint8 array of shape `size + (3,)`, with channels
            [background, triangle, rectangle], each 0/1.
    """
    while True:
        # `random_shapes` documents `min_size` as an int; `size[0] / 4` would be a float.
        image, labels = random_shapes(size, min_shapes=1, max_shapes=max_shapes,
                                      min_size=max(1, size[0] // 4), allow_overlap=False,
                                      multichannel=(CHANNELS == 3), num_trials=10)
        masks = []        # every annotated shape, in label order (used by the preview)
        triangles = []
        rectangles = []
        # Each label is (shape_name, ((row_start, row_end), (col_start, col_end))).
        for shape_name, ((r0, r1), (c0, c1)) in labels:
            msk = np.zeros(size, dtype=np.uint8)
            if shape_name == 'rectangle':
                rr, cc = rectangle((r0, c0), (r1, c1), shape=msk.shape)
                target = rectangles
            elif shape_name == 'triangle':
                # Apex at the top-centre of the bounding box, base along its bottom edge.
                rows = (r1, r0, r1, r1)
                cols = (c0, c1 - (c1 - c0) / 2, c1, c0)
                rr, cc = polygon(rows, cols, shape=msk.shape)
                target = triangles
            else:
                # Circles (and any other shape) are deliberately left as background.
                continue
            msk[rr, cc] = 1
            target.append(msk)
            masks.append(msk)

        if masks:
            break
        # No triangle or rectangle in this draw: sample a new image.

    foreground = _merge_binary_masks(masks, size)
    background = 1 - foreground
    triangle = _merge_binary_masks(triangles, size)
    mrectangle = _merge_binary_masks(rectangles, size)
    final_mask = np.dstack((background, triangle, mrectangle))

    if show_result:
        print(labels)
        fig = plt.figure(figsize=(8, 8))
        # One panel for the image plus one per *annotated* mask. Sizing by `labels`
        # crashed with IndexError whenever a circle was drawn, because circles get no mask.
        n_panels = len(masks) + 1
        fig.add_subplot(1, n_panels, 1)
        plt.imshow(image)
        for panel_idx, single_mask in enumerate(masks, start=2):
            fig.add_subplot(1, n_panels, panel_idx)
            plt.imshow(single_mask, cmap="Greys")
        plt.show()
    return image, final_mask

In [None]:
def preprocess_image(raw_image):
    """Turn a raw image into a float32 tensor with values divided by 255.

    When CHANNELS == 1 the tensor is reshaped to (SHAPE_SIZE, SHAPE_SIZE, 1), which gives
    grayscale input an explicit channel axis.
    """
    as_float = tf.convert_to_tensor(raw_image, np.float32)
    if CHANNELS == 1:
        as_float = tf.reshape(as_float, (SHAPE_SIZE, SHAPE_SIZE, CHANNELS))
    return as_float / 255.0

def preprocess_mask(raw_mask):
    """Convert a stacked class mask into a float32 tensor of shape (SHAPE_SIZE, SHAPE_SIZE, NUM_CLASSES).

    Values are not rescaled. The masks built by generate_image_mask are already 0/1.
    """
    target_shape = (SHAPE_SIZE, SHAPE_SIZE, NUM_CLASSES)
    return tf.reshape(tf.convert_to_tensor(raw_mask, np.float32), target_shape)

In [None]:
def generate_dataset(image_count):
    """Generate `image_count` synthetic (image, mask) training pairs.

    Each pair comes from generate_image_mask at SHAPE_SIZE x SHAPE_SIZE and is run through
    preprocess_image / preprocess_mask. The loop index is printed at every 10% of progress.
    When at least one pair was generated, the first image's shape is printed and the image
    is drawn with matplotlib.

    Args:
        image_count: number of pairs to generate (0 returns two empty lists).

    Returns:
        (images, masks): two lists of tensors of equal length `image_count`.
    """
    images = []
    masks = []
    for i in range(image_count):
        image, img_mask = generate_image_mask(size=(SHAPE_SIZE, SHAPE_SIZE), max_shapes=MAX_SHAPES)
        images.append(preprocess_image(image))
        masks.append(preprocess_mask(img_mask))
        # Print at each 10% step. Exact integer arithmetic avoids the float rounding of
        # `(i / image_count * 100.0) % 10 == 0`, which skipped steps such as i=90 of 300.
        if (i * 10) % image_count == 0:
            print(i)

    if not images:
        return images, masks

    # Display a sample image
    print(images[0].shape)
    if CHANNELS == 1:
        plt.imshow(images[0][:, :, 0], cmap='gray')
    else:
        plt.imshow(images[0])
    return images, masks

In [None]:
images_s1, masks_s1 = generate_dataset(IMAGE_COUNT_STAGE_ONE)
# Slice the image list and the mask list separately, then zip them into (image, mask) pairs.
img_ds_s1, msk_ds_s1 = map(tf.data.Dataset.from_tensor_slices, (images_s1, masks_s1))
img_msk_ds_s1 = tf.data.Dataset.zip((img_ds_s1, msk_ds_s1))

In [None]:
images_s2, masks_s2 = generate_dataset(IMAGE_COUNT_STAGE_TWO)
# Slice the image list and the mask list separately, then zip them into (image, mask) pairs.
img_ds_s2, msk_ds_s2 = map(tf.data.Dataset.from_tensor_slices, (images_s2, masks_s2))
img_msk_ds_s2 = tf.data.Dataset.zip((img_ds_s2, msk_ds_s2))

In [None]:
# Input pipeline for stage 1. A shuffle buffer as large as the dataset shuffles it completely.
# `prefetch` prepares upcoming batches in the background while the model trains.
ds1 = (
    img_msk_ds_s1
    .shuffle(buffer_size=IMAGE_COUNT_STAGE_ONE)
    .batch(BATCH_SIZE)
    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)
ds1

In [None]:
# Input pipeline for stage 2. A shuffle buffer as large as the dataset shuffles it completely.
# `prefetch` prepares upcoming batches in the background while the model trains.
ds2 = (
    img_msk_ds_s2
    .shuffle(buffer_size=IMAGE_COUNT_STAGE_TWO)
    .batch(BATCH_SIZE)
    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)
ds2

In [None]:
def make_generator_model(base_filters=32):
    """Build the encoder/decoder segmentation network used as the GAN generator.

    Encoder: up to SHAPE_SIZE // base_filters stages of
    Conv2D(2x2) -> BatchNormalization -> LeakyReLU -> MaxPool2D. It stops early if the
    feature map reaches 1 pixel wide.
    Decoder: one stride-2 Conv2DTranspose(3x3) -> BatchNormalization -> LeakyReLU stage per
    encoder stage. It stops as soon as the width is back to SHAPE_SIZE.
    Head: 1x1 Conv2D to NUM_CLASSES channels followed by a Softmax, so the model outputs
    per-pixel class probabilities (not logits).

    Args:
        base_filters: filter multiplier. Stage i uses base_filters * 2 * i filters, and the
            encoder depth is SHAPE_SIZE // base_filters. The default 32 reproduces the
            original architecture.

    Returns:
        An uncompiled tf.keras.Sequential model taking
        (batch, SHAPE_SIZE, SHAPE_SIZE, CHANNELS) inputs.
    """
    model = tf.keras.Sequential()
    max_depth = SHAPE_SIZE // base_filters
    true_depth = 1
    print('Max depth:{}'.format(max_depth))

    # Encoder: each stage halves the spatial resolution.
    for i in range(1, max_depth + 1):
        filters = int(base_filters * (i * 2))
        # Only the first layer of a Sequential model needs the input shape.
        shape_kwargs = {'input_shape': (SHAPE_SIZE, SHAPE_SIZE, CHANNELS)} if i == 1 else {}
        model.add(layers.Conv2D(filters, (2, 2), strides=(1, 1), padding='same',
                                use_bias=False, **shape_kwargs))
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU())
        model.add(layers.MaxPool2D())
        true_depth += 1

        # Stop before another pooling step could drive the spatial size below 1.
        if model.output_shape[1] == 1:
            print('True depth:{}'.format(true_depth))
            break

    # Decoder: stride-2 transposed convolutions double the resolution back towards SHAPE_SIZE.
    for i in range(true_depth, 1, -1):
        filters = int(base_filters * (i * 2))
        model.add(layers.Conv2DTranspose(filters, (3, 3), strides=(2, 2), padding='same', use_bias=False))
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU())
        if model.output_shape[1] == SHAPE_SIZE:
            break

    # 1x1 classification head with a per-pixel softmax over NUM_CLASSES.
    model.add(layers.Conv2D(NUM_CLASSES, (1, 1), strides=(1, 1), padding='same', use_bias=False))
    model.add(layers.Softmax())
    return model

In [None]:
# Build the segmentation generator: the external Fast-SCNN model if enabled, otherwise
# the in-notebook encoder/decoder.
if USE_FASTSCNN:
    import fastscnn
    generator = fastscnn.build((SHAPE_SIZE, SHAPE_SIZE, CHANNELS))
else:
    generator = make_generator_model()

# Smoke test: run the untrained generator on random noise and preview output channel 1.
noise = tf.random.normal([1, SHAPE_SIZE, SHAPE_SIZE, CHANNELS])
generated_image = generator(noise, training=False)
print(generated_image.shape)

plt.imshow(generated_image[0, :, :, 1], cmap=('gray' if CHANNELS == 1 else None));

In [None]:
# Run the (still untrained) generator on the first stage-1 image and preview output channel 1.
sample_batch = tf.reshape(tf.convert_to_tensor(images_s1[0]), (1, SHAPE_SIZE, SHAPE_SIZE, CHANNELS))
generated_image2 = generator(sample_batch, training=False)
print(generated_image2.shape)
plt.imshow(generated_image2[0, :, :, 1], cmap=('gray' if CHANNELS == 1 else None));

In [None]:
def make_discriminator_model():
    """Build a small convolutional critic for (SHAPE_SIZE, SHAPE_SIZE, NUM_CLASSES) masks.

    The network has two stride-2 5x5 convolutions (64, then 128 filters), each followed by
    LeakyReLU and Dropout(0.3). These feed Flatten and a single-unit Dense layer with no
    activation, so the model outputs one raw logit per example.
    """
    critic_layers = [
        layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                      input_shape=[SHAPE_SIZE, SHAPE_SIZE, NUM_CLASSES]),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        layers.Flatten(),
        layers.Dense(1),
    ]
    return tf.keras.Sequential(critic_layers)

In [None]:
def make_discriminator_model2():
    """Build the discriminator used by the training loop.

    Same architecture as make_discriminator_model(). It delegates to that function so the
    two definitions cannot drift apart. The separate name is kept because later cells call it.

    Returns:
        A fresh tf.keras.Sequential model that maps
        (batch, SHAPE_SIZE, SHAPE_SIZE, NUM_CLASSES) masks to one raw logit per example.
    """
    return make_discriminator_model()

In [None]:
discriminator = make_discriminator_model2()

# Stack a real input image with the generated mask along the channel axis and report the
# shape. NOTE(review): `tfinal` is only printed; it is not fed to the discriminator below.
timage = tf.reshape(tf.convert_to_tensor(images_s1[0]), (1, SHAPE_SIZE, SHAPE_SIZE, CHANNELS))
tfinal = tf.concat([timage, generated_image], axis=-1)
print(tfinal.shape)

# Score the untrained generator's output with the untrained discriminator.
decision = discriminator(generated_image)
print(decision.shape)
print(np.unique(decision))

In [None]:
# Binary cross-entropy on raw logits: the discriminator ends in Dense(1) with no activation.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
# Categorical cross-entropy on probabilities. make_generator_model() ends in a Softmax layer,
# so its output must NOT be treated as logits, which would apply softmax a second time.
# NOTE(review): if USE_FASTSCNN is enabled, confirm fastscnn.build() also ends in a softmax.
cce = tf.keras.losses.CategoricalCrossentropy(from_logits=False)


def discriminator_loss(real_output, fake_output):
    """GAN discriminator loss.

    Computes two binary cross-entropies on the discriminator's logits and returns their sum.
    The first uses target 1 for real masks; the second uses target 0 for generated masks.
    """
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return real_loss + fake_loss


def generator_loss(disc_output):
    """Adversarial generator loss: binary cross-entropy against target 1 ("real") on the discriminator's logits."""
    return cross_entropy(tf.ones_like(disc_output), disc_output)


def mask_loss(truth, gen_output):
    """Supervised per-pixel categorical cross-entropy between ground-truth masks and predicted class probabilities."""
    return cce(truth, gen_output)

In [None]:
# Both networks are optimised with Adam at the same learning rate.
ADAM_LEARNING_RATE = 1e-4
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=ADAM_LEARNING_RATE)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=ADAM_LEARNING_RATE)

In [None]:
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# Track both networks and both optimizers in a single checkpoint object.
_checkpointed = dict(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
)
checkpoint = tf.train.Checkpoint(**_checkpointed)

In [None]:
noise_dim = SHAPE_SIZE
# Fixed evaluation batch reused after every epoch, so the saved figures (and the GIF built
# from them) always show the same examples.
_seed_pairs = [generate_image_mask(size=(SHAPE_SIZE, SHAPE_SIZE), max_shapes=MAX_SHAPES)
               for _ in range(NUM_EXAMPLES_TO_GENERATE)]
seed = tf.convert_to_tensor([preprocess_image(raw_img) for raw_img, _ in _seed_pairs])
seed_truth = [preprocess_mask(raw_mask) for _, raw_mask in _seed_pairs]

In [None]:
# `tf.function` traces this step into a graph. `stage` is a plain Python string, so the
# stage branch below is resolved at trace time (one trace per distinct stage value).
@tf.function
def train_step(images, masks, stage):
    """Run one optimisation step for the generator and the discriminator on a single batch.

    In stage '1' the generator is trained with the supervised mask loss; in any other stage
    it is trained with the adversarial generator loss. The discriminator is updated in every
    stage to separate real masks from generated ones.

    Returns:
        (gen_loss, disc_loss) for this batch.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        predicted_masks = generator(images, training=True)

        real_logits = discriminator(masks, training=True)
        fake_logits = discriminator(predicted_masks, training=True)

        gen_loss = (mask_loss(masks, predicted_masks) if stage == '1'
                    else generator_loss(fake_logits))
        disc_loss = discriminator_loss(real_logits, fake_logits)

    gen_vars = generator.trainable_variables
    disc_vars = discriminator.trainable_variables
    gen_grads = gen_tape.gradient(gen_loss, gen_vars)
    disc_grads = disc_tape.gradient(disc_loss, disc_vars)

    generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
    discriminator_optimizer.apply_gradients(zip(disc_grads, disc_vars))

    return gen_loss, disc_loss

In [None]:
def train(dataset, epochs, stage):
    """Train the generator/discriminator pair for `epochs` passes over `dataset`.

    After every epoch this:
      - clears the cell output and saves/shows the sample figure for the fixed `seed` batch;
      - saves a checkpoint every 15 epochs;
      - prints the timing and the mean generator/discriminator loss over the epoch's batches.

    A final sample figure is produced after the last epoch.

    Args:
        dataset: iterable of (image_batch, mask_batch) pairs, e.g. ds1 or ds2.
        epochs: number of passes over `dataset`.
        stage: stage label passed to train_step ('1' selects the supervised mask loss). It is
            also used in checkpoint and figure file names.

    Raises:
        ValueError: if `dataset` yields no batches in an epoch.
    """
    for epoch in range(epochs):
        start = time.time()
        gen_losses = []
        disc_losses = []

        for image_batch, mask_batch in dataset:
            gen_loss, disc_loss = train_step(image_batch, mask_batch, stage)
            gen_losses.append(gen_loss)
            disc_losses.append(disc_loss)

        if not gen_losses:
            # Fail with a clear message instead of referencing an unbound `gen_loss` below.
            raise ValueError('train() received a dataset that yielded no batches')

        # Produce images for the GIF as we go
        display.clear_output(wait=True)
        generate_and_save_images(generator,
                                 epoch + 1,
                                 stage,
                                 seed)

        # Save the model every 15 epochs
        if (epoch + 1) % 15 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix + "_" + stage)

        # Report the mean over the whole epoch; the last batch alone is a noisy estimate.
        mean_gen_loss = tf.reduce_mean(tf.stack(gen_losses)).numpy()
        mean_disc_loss = tf.reduce_mean(tf.stack(disc_losses)).numpy()
        print('Stage: {}'.format(stage))
        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))
        print('Generator loss is {}'.format(mean_gen_loss))
        print('Discriminator loss is {}'.format(mean_disc_loss))

    # Generate after the final epoch
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epochs,
                             stage,
                             seed)

In [None]:
def generate_and_save_images(model, epoch, stage, test_input, ground_truth=None):
    """Plot model predictions for a fixed input batch next to the ground truth, then save and show.

    Each row of the grid is one example. The columns are:
      - every predicted class channel (NUM_CLASSES panels);
      - the ground-truth channels 1..NUM_CLASSES-1 (channel 0, the background, is skipped);
      - the input image.
    The figure is written to 'image_at_epoch_s{stage}_{epoch:04d}.png', shown, and closed.

    Args:
        model: generator to evaluate.
        epoch: epoch number used in the output file name.
        stage: stage label used in the output file name.
        test_input: batch of preprocessed input images (the training loop passes `seed`).
        ground_truth: masks aligned with `test_input`. Defaults to the global `seed_truth`.
    """
    if ground_truth is None:
        ground_truth = seed_truth

    # training=False so all layers (notably BatchNormalization) run in inference mode.
    predictions = model(test_input, training=False)

    n_rows = predictions.shape[0]
    # NUM_CLASSES prediction panels + (NUM_CLASSES - 1) truth panels + 1 input panel.
    n_cols = NUM_CLASSES + (NUM_CLASSES - 1) + 1
    fig = plt.figure(figsize=(NUM_EXAMPLES_TO_GENERATE, NUM_EXAMPLES_TO_GENERATE))

    def _show(row, col, data, **imshow_kwargs):
        # Draw `data` in grid cell (row, col) without axis decorations.
        ax = fig.add_subplot(n_rows, n_cols, row * n_cols + col + 1)
        ax.imshow(data, **imshow_kwargs)
        ax.axis('off')

    for row in range(n_rows):
        col = 0
        for cls in range(NUM_CLASSES):
            _show(row, col, predictions[row, :, :, cls], cmap='gray')
            col += 1
        for cls in range(1, NUM_CLASSES):
            _show(row, col, ground_truth[row][:, :, cls], cmap='gray')
            col += 1
        if CHANNELS == 1:
            _show(row, col, test_input[row][:, :, 0], cmap='gray')
        else:
            _show(row, col, test_input[row])

    plt.savefig('image_at_epoch_s{}_{:04d}.png'.format(stage, epoch))
    plt.show()
    # Release the figure; this function runs once per epoch.
    plt.close(fig)

In [None]:
%%time
train(ds1, EPOCHS_STAGE_ONE, '1')

In [None]:
%%time
train(ds2, EPOCHS_STAGE_TWO, '2')

In [None]:
anim_file = 'seg-shapes.gif'

# Collect the per-epoch figures written by generate_and_save_images. Their names look like
# image_at_epoch_s<stage>_<epoch:04d>.png, so a lexicographic sort orders them by stage, then epoch.
filenames = sorted(glob.glob('image*.png'))
if not filenames:
    raise FileNotFoundError('No image*.png frames found; run the training cells first.')

with imageio.get_writer(anim_file, mode='I') as writer:
    for filename in filenames:
        writer.append_data(imageio.imread(filename))
    # Append the final frame a second time so the last epoch stays on screen one frame longer.
    writer.append_data(imageio.imread(filenames[-1]))

import IPython
display.Image(filename=anim_file)