<a href="https://colab.research.google.com/github/anassboussarhan/alphamatesegmentation/blob/master/alphamatte.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import tensorflow as tf
from tensorflow.python.lib.io import file_io
import os

In [0]:
# Mount Google Drive into the Colab VM; the dataset, log and checkpoint paths
# configured further down in this notebook all live under /content/drive/.
from google.colab import drive
drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


FUNCTIONS

In [0]:
def conv(x, name, filters, kernel_size=3, strides=1, dilation=1):
    """Apply a "same"-padded 2-D convolution inside its own variable scope.

    Args:
        x: Input tensor handed to ``tf.layers.conv2d``.
        name: Variable scope under which the layer's variables are created.
        filters: Number of output channels.
        kernel_size: Convolution kernel size (default 3).
        strides: Convolution stride (default 1).
        dilation: Dilation rate (default 1).

    Returns:
        The convolution output; no activation is applied here.
    """
    with tf.variable_scope(name):
        convolved = tf.layers.conv2d(
            x,
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            dilation_rate=dilation,
            padding="same",
        )
    return convolved


def instance_norm(x, name, epsilon=1e-5):
    """Normalize ``x`` with statistics taken over axes 1 and 2 of each sample.

    Mean and variance are computed per sample and per last-axis entry
    (axes ``[1, 2]`` are reduced), then a learned scale ("gamma", initialized
    to ones) and offset ("beta", initialized to zeros) are applied.

    Args:
        x: Input tensor; presumably NHWC, so axes 1 and 2 are spatial --
            TODO confirm against callers.
        name: Variable scope holding the ``gamma`` and ``beta`` variables.
        epsilon: Small constant added to the variance for numerical stability.

    Returns:
        The normalized tensor.
    """
    with tf.variable_scope(name):
        channels = x.shape[-1]
        gamma = tf.get_variable(name="gamma", initializer=tf.ones([channels]))
        beta = tf.get_variable(name="beta", initializer=tf.zeros([channels]))
        mean, variance = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
        normalized = tf.nn.batch_normalization(
            x, mean, variance, offset=beta, scale=gamma,
            variance_epsilon=epsilon, name="norm")
    return normalized

In [0]:
def loss_fun(images, gt_masks, alpha_mattes, epsilon=1e-6):
    """Sum of a smoothed-L1 alpha loss and a smoothed-L1 color loss.

    Both terms use ``sqrt(diff**2 + epsilon)`` summed over every element:
      * alpha term: difference between ``gt_masks`` and ``alpha_mattes``;
      * color term: difference between the images weighted by each mask
        (masks are tiled 3x along the last axis before multiplying).

    NOTE(review): the two masks are subtracted directly, so they presumably
    need to share the same value range -- verify against the input pipeline.

    Args:
        images: Image batch; the tiling by 3 implies a last axis of size 3.
        gt_masks: Ground-truth masks with a single-entry last axis.
        alpha_mattes: Predicted mattes with a single-entry last axis.
        epsilon: Smoothing constant that keeps the square root differentiable.

    Returns:
        Scalar tensor ``alpha_loss + color_loss``.
    """
    rgb_multiples = (1, 1, 1, 3)

    alpha_loss = tf.reduce_sum(
        tf.sqrt(tf.square(gt_masks - alpha_mattes) + epsilon))

    gt_foreground = tf.tile(gt_masks, multiples=rgb_multiples) * images
    predicted_foreground = tf.tile(alpha_mattes, multiples=rgb_multiples) * images
    color_loss = tf.reduce_sum(
        tf.sqrt(tf.square(gt_foreground - predicted_foreground) + epsilon))

    return alpha_loss + color_loss

In [0]:
def iou(real_B, fake_B):
    """Mean intersection-over-union of two mask batches thresholded at 0.5.

    Each mask is binarized with ``>= 0.5``; intersection and union pixel
    counts are accumulated per sample over axes 1, 2 and 3, and the
    per-sample ratios are averaged over the batch.

    A sample whose union is empty (neither mask has a pixel >= 0.5) is scored
    as 1.0 -- both masks agree there is nothing to segment -- instead of the
    0/0 = NaN that would otherwise poison the batch mean.

    Args:
        real_B: Ground-truth masks (4-D, reduced over axes 1-3).
        fake_B: Predicted masks with the same shape as ``real_B``.

    Returns:
        Scalar float32 tensor with the batch-mean IoU.
    """
    real_fg = tf.greater_equal(real_B, 0.5)
    fake_fg = tf.greater_equal(fake_B, 0.5)

    intersection = tf.reduce_sum(
        tf.cast(tf.logical_and(real_fg, fake_fg), dtype=tf.float32), axis=[1, 2, 3])
    union = tf.reduce_sum(
        tf.cast(tf.logical_or(real_fg, fake_fg), dtype=tf.float32), axis=[1, 2, 3])

    # Guard the division: swap empty unions for 1 so no NaN is ever produced
    # (tf.where would still propagate NaN gradients/values from the untaken
    # branch if the raw division were evaluated).
    union_is_empty = tf.equal(union, 0.0)
    safe_union = tf.where(union_is_empty, tf.ones_like(union), union)
    per_sample_iou = tf.where(union_is_empty,
                              tf.ones_like(union),
                              intersection / safe_union)
    return tf.reduce_mean(per_sample_iou)

In [0]:
def segmentation_block(x):
    """Predict a coarse 2-channel mask for ``x`` with a small dilated CNN.

    The input is downsampled by 2 along two parallel paths (a strided
    convolution and a max-pool) that are concatenated, passed through four
    densely connected dilated convolutions (dilation 2, 4, 6, 8), reduced to
    2 channels and resized back to the input's axis-1 / axis-2 extent.

    Variables live in the "segmentation_block" scope with ``AUTO_REUSE``, so
    calling this function again shares the same weights.

    Args:
        x: Input batch; presumably NHWC images -- TODO confirm.

    Returns:
        Tensor with 2 channels and the same axis-1 / axis-2 size as ``x``.
    """
    x_shape = tf.shape(x)
    # tf.image.resize_images expects [new_height, new_width]; axis 1 comes first.
    out_h, out_w = x_shape[1], x_shape[2]
    with tf.variable_scope("segmentation_block", reuse=tf.AUTO_REUSE):
        conv1 = conv(x, name="conv1", filters=13, strides=2)
        # padding="same" keeps the pooled size equal to the strided conv's
        # output (ceil(n/2)) for odd input sizes; for even sizes it is
        # identical to the default "valid" padding.
        pool1 = tf.layers.max_pooling2d(x, pool_size=2, strides=2, padding="same")
        conv1_concat = tf.concat([conv1, pool1], axis=3)
        conv2 = tf.nn.relu(conv(conv1_concat, name="conv2", filters=16, dilation=2))
        conv2_concat = tf.concat([conv1_concat, conv2], axis=3)
        conv3 = tf.nn.relu(conv(conv2_concat, name="conv3", filters=16, dilation=4))
        conv3_concat = tf.concat([conv2_concat, conv3], axis=3)
        conv4 = tf.nn.relu(conv(conv3_concat, name="conv4", filters=16, dilation=6))
        conv4_concat = tf.concat([conv3_concat, conv4], axis=3)
        conv5 = tf.nn.relu(conv(conv4_concat, name="conv5", filters=16, dilation=8))
        # Only the dilated-conv outputs feed the final classifier layer.
        conv5_concat = tf.concat([conv2, conv3, conv4, conv5], axis=3)
        conv6 = tf.nn.relu(conv(conv5_concat, name="conv6", filters=2))
        pred = tf.image.resize_images(conv6, size=[out_h, out_w])
    return pred

In [0]:
def feathering_block(x, coarse_mask):
    """Refine a coarse 2-channel mask into a single-channel alpha matte.

    The image, the coarse mask, the squared image and the image multiplied by
    the first mask channel are concatenated and fed through two convolutions
    that predict three per-pixel coefficient maps ``a``, ``b`` and ``c``. The
    matte is ``sigmoid(a * foreground + b * background + c)``.

    Variables live in the "feathering_block" scope with ``AUTO_REUSE``.

    Args:
        x: Image batch; the tiling by 3 implies a last axis of size 3.
        coarse_mask: Tensor with 2 channels on axis 3; the first channel is
            used as foreground, the second as background.

    Returns:
        The alpha matte, squashed to (0, 1) by the sigmoid.
    """
    with tf.variable_scope("feathering_block", reuse=tf.AUTO_REUSE):
        foreground, background = tf.split(coarse_mask, num_or_size_splits=2, axis=3)
        squared = tf.square(x)
        masked = x * tf.tile(foreground, multiples=(1, 1, 1, 3))

        features = tf.concat([x, coarse_mask, squared, masked], axis=3)

        hidden = conv(features, name="conv1", filters=32)
        hidden = instance_norm(hidden, name="norm1")
        hidden = tf.nn.relu(hidden)
        coefficients = conv(hidden, name="conv4", filters=3)

        a, b, c = tf.split(coefficients, num_or_size_splits=3, axis=3)

        blended = a * foreground + b * background + c
    return tf.nn.sigmoid(blended)

In [0]:
def _extract_features(example):
    """Parse one serialized example into a fixed-size (image, mask) pair.

    The example must carry PNG-encoded bytes under the "image" and "mask"
    keys. The image is decoded with 3 channels, the mask with 1; both are
    resized to 800x600 and reshaped so the static shape is fully known.

    NOTE(review): no rescaling is applied after decoding, so values keep the
    numeric range of the decoded uint8 PNG data -- confirm that downstream
    code expects this.

    Args:
        example: Scalar string tensor holding a serialized ``tf.Example``.

    Returns:
        Tuple ``(images, masks)`` shaped ``[800, 600, 3]`` and ``[800, 600, 1]``.
    """
    target_size = [800, 600]
    schema = {
        "image": tf.FixedLenFeature((), tf.string),
        "mask": tf.FixedLenFeature((), tf.string)
    }
    parsed = tf.parse_single_example(example, schema)

    def _decode(png_bytes, channels):
        # Decode, resize, then pin the static shape to target_size + [channels].
        decoded = tf.image.decode_png(png_bytes, channels=channels, dtype=tf.uint8)
        resized = tf.image.resize_images(decoded, target_size)
        return tf.reshape(resized, target_size + [channels])

    images = _decode(parsed["image"], 3)
    masks = _decode(parsed["mask"], 1)
    return images, masks


def create_one_shot_iterator(filenames, batch_size, num_epoch):
    """Build a one-shot iterator over parsed, shuffled, batched TFRecords.

    Pipeline order: parse -> shuffle (buffer of ``batch_size`` examples) ->
    batch -> repeat ``num_epoch`` times.

    Args:
        filenames: TFRecord file name(s) passed to ``tf.data.TFRecordDataset``.
        batch_size: Batch size; also used as the shuffle buffer size.
        num_epoch: Number of passes over the batched dataset.

    Returns:
        A one-shot iterator yielding ``(images, masks)`` batches.
    """
    pipeline = (
        tf.data.TFRecordDataset(filenames)
        .map(_extract_features)
        .shuffle(buffer_size=batch_size)
        .batch(batch_size)
        .repeat(num_epoch)
    )
    return pipeline.make_one_shot_iterator()


def create_initializable_iterator(filenames, batch_size):
    """Build an initializable iterator over parsed, shuffled, batched TFRecords.

    Pipeline order: parse -> shuffle (buffer of ``batch_size`` examples) ->
    batch. The dataset is not repeated, and the returned iterator's
    ``initializer`` op has to be run before the first ``get_next()`` value is
    evaluated.

    Args:
        filenames: TFRecord file name(s) passed to ``tf.data.TFRecordDataset``.
        batch_size: Batch size; also used as the shuffle buffer size.

    Returns:
        An initializable iterator yielding ``(images, masks)`` batches.
    """
    pipeline = (
        tf.data.TFRecordDataset(filenames)
        .map(_extract_features)
        .shuffle(buffer_size=batch_size)
        .batch(batch_size)
    )
    return pipeline.make_initializable_iterator()


def augment_dataset(images, masks, size=None, augment=True):
    """Optionally flip/rotate an (images, masks) batch and resize it.

    When ``augment`` is true, one random decision per call is applied to the
    whole batch: a left-right flip with probability 1/2, followed by a
    rotation by a single angle drawn uniformly from [-1/6, 1/6) (passed to
    ``tf.contrib.image.rotate``, which takes radians). Images and masks
    always receive the same transform so they stay aligned.

    Args:
        images: Batched image tensor.
        masks: Batched mask tensor aligned with ``images``.
        size: Optional target size handed to ``tf.image.resize_images``;
            ``None`` skips resizing.
        augment: Whether to apply the random flip and rotation.

    Returns:
        Tuple ``(images, masks)`` after the requested transforms.
    """
    if augment:
        cond_flip_lr = tf.cast(tf.random_uniform([], maxval=2, dtype=tf.int32), tf.bool)
        # Float literals: under Python 2, -1/6 and 1/6 are integer divisions
        # (-1 and 0), which would silently change the rotation range.
        cond_rotate = tf.random_uniform([], minval=-1.0 / 6.0, maxval=1.0 / 6.0)

        def _unchanged():
            return images, masks

        def _flipped():
            return (tf.map_fn(tf.image.flip_left_right, images),
                    tf.map_fn(tf.image.flip_left_right, masks))

        images, masks = tf.cond(cond_flip_lr, _flipped, _unchanged)

        images = tf.contrib.image.rotate(images, angles=cond_rotate)
        masks = tf.contrib.image.rotate(masks, angles=cond_rotate)

    if size is not None:
        images = tf.image.resize_images(images, size)
        masks = tf.image.resize_images(masks, size)

    return images, masks

In [0]:
# --- Run configuration -------------------------------------------------------
# Run mode; validated further down to lie in 1..3 (meaning of each value is
# not documented in this notebook).
mode = 1

# TFRecord inputs (also used as file patterns when counting samples).
train_files = "/content/drive/My Drive/Projetand/train-00001-of-00001"
test_files = "/content/drive/My Drive/Projetand/test-00001-of-00001"

# Output locations for TensorBoard summaries and checkpoints.
log_dir = '/content/drive/My Drive/Projetand/log/'
ckpt_dir = '/content/drive/My Drive/Projetand/cpkt/'

# Batching / optimisation settings.
train_batch_size = 20
test_batch_size = 2
num_epochs = 20000
learning_rate = 1e-3

# Iteration number of the checkpoint to restore; None starts from scratch.
resume = None



   


In [0]:
 
# Fail fast on an unsupported run mode. ValueError (a subclass of Exception,
# so existing handlers still match) signals a bad configuration value.
if mode is None or mode <= 0 or mode > 3:
    raise ValueError("Invalid mode")


    


def _count_records(file_pattern):
    """Count the records in every TFRecord file matching ``file_pattern``."""
    total = 0
    for path in file_io.get_matching_files(file_pattern):
        for _ in tf.python_io.tf_record_iterator(path):
            total += 1
    return total


# Each count reads its files end to end once.
num_train_samples = _count_records(train_files)

num_test_samples = _count_records(test_files)

   

# Input pipelines. The training iterator repeats for num_epochs passes; the
# test iterator yields the whole test set as a single batch
# (batch_size=num_test_samples) and, being initializable, needs its
# initializer run before any test tensor is evaluated.
train_iterator = create_one_shot_iterator(train_files, train_batch_size, num_epoch=num_epochs)
test_iterator = create_initializable_iterator(test_files, batch_size=num_test_samples)

# Training graph: augment -> coarse segmentation -> feathering -> loss.
next_images, next_masks = train_iterator.get_next()
next_images, next_masks = augment_dataset(next_images, next_masks)
coarse_masks = segmentation_block(next_images)
alpha_mattes = feathering_block(next_images, coarse_masks)
loss = loss_fun(next_images, next_masks, alpha_mattes)

# Test graph: same network without augmentation. Both blocks open their
# variable scopes with reuse=tf.AUTO_REUSE, so these calls share the weights
# created for the training graph above.
test_images, test_masks = test_iterator.get_next()
test_images, test_masks = augment_dataset(test_images, test_masks,augment=False)
test_coarse_masks = segmentation_block(test_images)
test_alpha_mattes = feathering_block(test_images, test_coarse_masks)
test_loss = loss_fun(test_images, test_masks, test_alpha_mattes)

# IoU metrics (masks thresholded at 0.5 inside iou()).
train_iou = iou(next_masks, alpha_mattes)
test_iou = iou(test_masks, test_alpha_mattes)

all_trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

# One Adam step over every trainable variable of both blocks.
train_op = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss, var_list=all_trainable_vars)

# TensorBoard summaries for the training batch: input, ground-truth cut-out
# (mask * image) and predicted cut-out (matte * image).
summary = tf.summary.FileWriter(logdir=log_dir)
image_summary = tf.summary.image("image", next_images)
gt_summary = tf.summary.image("gt", next_masks * next_images)
result_summary = tf.summary.image("result", alpha_mattes * next_images)
images_summary = tf.summary.merge([image_summary, gt_summary, result_summary])

# Same three image summaries for the test batch.
test_image_summary = tf.summary.image("test_image", test_images)
test_gt_summary = tf.summary.image("test_gt", test_masks * test_images)
test_result_summary = tf.summary.image("test_result", test_alpha_mattes * test_images)
test_images_summary = tf.summary.merge([test_image_summary, test_gt_summary, test_result_summary])

loss_summary = tf.summary.scalar("loss", loss)
test_loss_summary = tf.summary.scalar("test_loss", test_loss)

train_iou_sum = tf.summary.scalar("train_iou", train_iou)
test_iou_sum = tf.summary.scalar("test_iou", test_iou)

# Only trainable variables are checkpointed; the optimizer's own variables
# (not in tf.trainable_variables()) are therefore not saved or restored.
saver = tf.train.Saver(var_list=tf.trainable_variables())

  
     

def get_session(sess):
    """Unwrap nested session wrappers until the object of class ``Session``.

    Follows the ``_sess`` attribute of each wrapper until it reaches an
    object whose class is named exactly ``'Session'``, and returns it.

    Args:
        sess: A session, or a wrapper exposing the wrapped object as ``_sess``.

    Returns:
        The innermost object whose type name is ``'Session'``.

    Raises:
        AttributeError: If a non-``Session`` object without ``_sess`` is hit.
    """
    current = sess
    while True:
        if type(current).__name__ == 'Session':
            return current
        current = current._sess


# Training loop.
#
# Fixes relative to the previous version of this cell:
#   * the resume offset is honoured (the loop used to restart at 0 because
#     `for it in range(300)` overwrote it);
#   * checkpoints are written with the same "<ckpt_dir>/ckpt-<step>" prefix
#     that the restore path reads (they used to be saved as "<ckpt_dir>/-<step>");
#   * image summaries are only evaluated on the steps where they are written;
#   * the summary writer is flushed at the end so no events are lost.
with tf.train.MonitoredTrainingSession() as sess:

    num_iterations = 300        # training steps performed by one run of this cell
    image_summary_every = 10    # steps between image summaries
    checkpoint_every = 200      # steps between intermediate checkpoints

    # Saver.save() raises if the target directory is missing.
    if not os.path.isdir(ckpt_dir):
        os.makedirs(ckpt_dir)
    ckpt_prefix = os.path.join(ckpt_dir, "ckpt")

    start_it = 0
    if resume is not None and resume > 0:
        # Saver works on the raw tf.Session, not the monitored wrapper.
        saver.restore(get_session(sess), ckpt_prefix + "-{it}".format(it=resume))
        start_it = resume + 1

    for it in range(start_it, start_it + num_iterations):
        write_images = it % image_summary_every == 0

        # Fetch summaries in the same run() as train_op so they describe the
        # batch that was actually trained on.
        fetches = [train_op, loss_summary, train_iou_sum]
        if write_images:
            fetches.append(images_summary)
        results = sess.run(fetches)

        summary.add_summary(results[1], it)
        summary.add_summary(results[2], it)
        if write_images:
            summary.add_summary(results[3], it)

        if it % checkpoint_every == 0:
            ckpt_path = saver.save(get_session(sess), save_path=ckpt_prefix,
                                   write_meta_graph=False, global_step=it)
            print("Checkpoint saved as: {ckpt_path}".format(ckpt_path=ckpt_path))

    # Final checkpoint, numbered one past the last executed step.
    it = start_it + num_iterations
    ckpt_path = saver.save(get_session(sess), save_path=ckpt_prefix, write_meta_graph=False,
                           global_step=it)
    print("Checkpoint saved as: {ckpt_path}".format(ckpt_path=ckpt_path))

    # Push buffered events to disk; the writer is left open for re-runs.
    summary.flush()


INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
