In [None]:
# Reload imported modules before each cell executes, so edits to local
# modules (such as the input pipeline imported below) are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import tensorflow as tf
import numpy as np
import os
from PIL import Image
import time
from tqdm import tqdm

import sys
sys.path.append('..')
from input_pipeline import Pipeline

In [None]:
# Number of label ids; passed to the input pipeline as `num_labels` and used
# to size the visualization palette. The value depends on the dataset:
#   modanet has 13 labels + background label = 14
#   cityscapes has 19 labels + ignore label = 20
NUM_LABELS = 20

In [None]:
# Visualization palette: one RGB color (uint8 array of length 3) per label id.
# A dedicated, seeded generator keeps the colors identical across kernel
# restarts without touching NumPy's global random state. Channels are drawn
# from [1, 255] so that no label can receive pure black, which the overlay
# code would treat as "no label" (its alpha mask is built from `mask > 0`).
_palette_rng = np.random.RandomState(0)
color_by_label = {
    j: _palette_rng.randint(1, 256, size=3, dtype='uint8')
    for j in range(NUM_LABELS)
}

def get_color_mask(sparse_mask, colors=None):
    """Convert a mask of integer label ids into an RGB color image.

    Arguments:
        sparse_mask: an integer array of label ids
            (typically of shape [height, width]).
        colors: an optional dict mapping label id -> RGB color (three
            uint8 values). Defaults to the notebook-level `color_by_label`.
    Returns:
        a uint8 array of shape `sparse_mask.shape + (3,)`. Pixels whose
        label has no entry in the palette are left black.
    """
    if colors is None:
        colors = color_by_label

    sparse_mask = np.asarray(sparse_mask)
    color_mask = np.zeros(sparse_mask.shape + (3,), dtype='uint8')

    # Iterate over the palette itself rather than a hard-coded label count,
    # so every label that has a color gets painted. Each pixel carries exactly
    # one label, so painting into a single buffer is equivalent to building
    # one colored layer per label and taking their maximum.
    for label, color in colors.items():
        color_mask[sparse_mask == label] = color

    return color_mask

# Build a graph

In [None]:
# Start from an empty default graph so that re-running this cell does not
# keep adding ops to the previous one.
tf.reset_default_graph()

# Collect the full paths of all .tfrecords files in the dataset folder,
# ordered by file name.
dataset_path = '/home/dan/datasets/cityscapes/edanet/train/'
filenames = [
    os.path.join(dataset_path, name)
    for name in sorted(os.listdir(dataset_path))
    if name.endswith('.tfrecords')
]

params = {
    'batch_size': 16,
    'num_labels': NUM_LABELS,
    'image_height': 512,
    'image_width': 1024,
}

# Build the training pipeline and an iterator that is (re)initialized
# from its dataset through the `init` op.
pipeline = Pipeline(filenames, is_training=True, params=params)
dataset = pipeline.dataset
iterator = tf.data.Iterator.from_structure(
    dataset.output_types,
    dataset.output_shapes
)
init = iterator.make_initializer(dataset)
features, labels = iterator.get_next()

# Show an image

In [None]:
# Initialize the iterator and pull a single batch out of the pipeline.
with tf.Session() as session:
    session.run(init)
    fetched = session.run([features, labels])
    output_images, output_masks = fetched

In [None]:
# Index of the sample to inspect within the fetched batch.
i = 0

# Rescale to 8-bit for display
# (assumes pixel values are floats in [0, 1] - TODO confirm against the pipeline).
scaled = output_images[i] * 255.0
image = scaled.astype(np.uint8)
Image.fromarray(image)

# Show masks

In [None]:
# Bottom layer: the image itself, made fully opaque.
t = Image.fromarray(image)
t.putalpha(255)

# Top layer: the colorized mask, semi-transparent (alpha 150) wherever
# any color channel is nonzero and fully transparent elsewhere.
mask = get_color_mask(output_masks[i])
m = Image.fromarray(mask)
is_colored = (mask > 0).any(2)
alpha = 150 * is_colored.astype('uint8')
m.putalpha(Image.fromarray(alpha))

# Blend the mask layer onto the image in place and display the result.
t.alpha_composite(m)
t

# Measure speed

In [None]:
# Time 105 consecutive batch fetches from a freshly initialized iterator.
step_durations = []
with tf.Session() as sess:
    sess.run(init)
    for _ in range(105):
        tic = time.perf_counter()
        output = sess.run([features, labels])
        toc = time.perf_counter()
        step_durations.append(toc - tic)

# Drop the first 5 measurements (presumably warm-up) and report the
# mean and standard deviation, in seconds per batch, of the remaining ones.
times = np.asarray(step_durations)[5:]
print(times.mean(), times.std())