In [None]:
%matplotlib inline
import importlib
import utils2; importlib.reload(utils2)
from utils2 import *

In [None]:
path = './'

In [None]:
images_path = path + 'images/'
labels_path = path + 'labels/'

In [None]:
images = glob.glob(images_path + '*.png')
print (images[:3])

labels = [labels_path + os.path.basename(i)[:-4] + '_L.png' for i in images]
print (labels[:3])

In [None]:
image_size = (480,360)

In [None]:
def open_image(img, size=None):
    """Load an image file and resize it with nearest-neighbour resampling.

    Args:
        img: path of the image file.
        size: PIL (width, height) target size; defaults to the global `image_size`.

    Returns:
        np.ndarray with shape (height, width) or (height, width, channels).
    """
    if size is None:
        size = image_size
    # Context manager closes the underlying file handle deterministically.
    # NEAREST is used so label colours are never interpolated into new values.
    with Image.open(img) as im:
        return np.array(im.resize(size, Image.NEAREST))

In [None]:
imgs = np.stack([open_image(i) for i in images])
imgs.shape

In [None]:
lbs = np.stack([open_image(l) for l in labels])
lbs.shape

In [None]:
imgs = imgs/255.

In [None]:
mean = imgs.mean() 
std = imgs.std()
mean, std

In [None]:
imgs-=imgs.mean()
imgs/=imgs.std()

In [None]:
# save_array('imgs.bc', imgs)
# save_array('lbs.bc', labels)

In [None]:
class BatchIndices(object):
    """Endlessly cycling, thread-safe iterator over batches of indices in [0, n).

    Each next() returns a numpy array of up to `bs` indices. The last batch of a
    pass may be shorter. The following call starts a new pass, in a fresh random
    order when `shuffle` is True and in order 0..n-1 otherwise.
    """
    def __init__(self, n, bs, shuffle=False):
        self.n,self.bs,self.shuffle = n,bs,shuffle
        # Guards the read-modify-write of self.curr in __next__ so concurrent
        # callers never receive overlapping slices.
        self.lock = threading.Lock()
        self.reset()

    def reset(self):
        """Begin a new pass over the indices."""
        self.idxs = (np.random.permutation(self.n)
                     if self.shuffle else np.arange(0, self.n))
        self.curr = 0

    def __iter__(self):
        # Completes the iterator protocol, so the object works with for-loops and
        # iter(), and with loaders that require both __iter__ and __next__.
        return self

    def __next__(self):
        with self.lock:
            if self.curr >= self.n: self.reset()
            ni = min(self.bs, self.n-self.curr)
            res = self.idxs[self.curr:self.curr+ni]
            self.curr += ni
            return res

In [None]:
# Demo: 10 indices in shuffled batches of 3. The 4th batch is the 1-element
# remainder, and the 5th starts a new shuffled pass.
bi = BatchIndices(10, 3, shuffle=True)
[next(bi) for _ in range(5)]

In [None]:
class segm_generator(object):
    """Endless batch iterator of aligned (image, label) crops for segmentation.

    Every selected image/label pair is cropped to `out_sz` (rows, cols) at the same
    position. In training mode the crop offset is random and the pair is flipped
    left-right with probability 0.5. Otherwise the bottom-right crop is used and
    indices are visited in order. Labels are returned with shape
    (batch, rows*cols, ych).
    """
    def __init__(self, x, y, bs=64, out_sz=(224,224), train=True):
        self.x, self.y, self.bs, self.train = x,y,bs,train
        # x must be 4-D: (n, rows, cols, channels).
        self.n, self.ri, self.ci, _ = x.shape
        # Shuffle only in training mode.
        self.idx_gen = BatchIndices(self.n, bs, train)
        self.ro, self.co = out_sz
        # A 3-D y (e.g. integer id maps) is treated as a single channel.
        self.ych = self.y.shape[-1] if len(y.shape)==4 else 1

    def get_slice(self, i,o):
        """Slice of length o inside [0, i): random start when training, else the last o."""
        start = random.randint(0, i-o) if self.train else (i-o)
        return slice(start, start+o)

    def get_item(self, idx):
        """Return the (x, y) crop for sample idx, randomly mirrored when training."""
        slice_r = self.get_slice(self.ri, self.ro)
        slice_c = self.get_slice(self.ci, self.co)
        x = self.x[idx, slice_r, slice_c]
        y = self.y[idx, slice_r, slice_c]
        if self.train and (random.random()>0.5):
            # Flip columns of both arrays so image and label stay aligned.
            y = y[:,::-1]
            x = x[:,::-1]
        return x, y

    def __iter__(self):
        # Completes the iterator protocol; generator-accepting training loops
        # may check for __iter__ as well as __next__.
        return self

    def __next__(self):
        idxs = next(self.idx_gen)
        items = (self.get_item(idx) for idx in idxs)
        xs,ys = zip(*items)
        return np.stack(xs), np.stack(ys).reshape(len(ys), -1, self.ych)

In [None]:
sg = segm_generator(imgs, lbs, 4, train=False)
b_img, b_label = next(sg)

In [None]:
plt.imshow(b_img[0]*0.3+0.4);

In [None]:
def parse_code(l):
    """Parse one line of label_colors.txt into ((R, G, B), class_name).

    The line has 2 or 3 tab-separated fields. The first holds whitespace-separated
    integer colour components, and the class name is the last field (an optional
    middle field is ignored).

    Raises:
        ValueError: if the line does not have 2 or 3 fields, or a colour
            component is not an integer.
    """
    fields = l.strip().split("\t")
    if len(fields) not in (2, 3):
        raise ValueError("expected 2 or 3 tab-separated fields, got %r" % l)
    # split() (not split(' ')) tolerates repeated spaces between components.
    return tuple(int(v) for v in fields[0].split()), fields[-1]

In [None]:
label_codes, label_names = zip(*[parse_code(l) for l in open(labels_path+"label_colors.txt")])
label_codes, label_names = list(label_codes), list(label_names)
label_codes[:5], label_names[:5]

len(label_names)

In [None]:
label_codes, label_names

In [None]:
code2id = {v:k for k,v in enumerate(label_codes)}

id2code = {k:v for k,v in enumerate(label_codes)}

In [None]:
code2id

In [None]:
failed_code = len(label_codes)+1
failed_code

In [None]:
n,r,c,ch = imgs.shape
n, r, c, ch

In [None]:
def conv_one_label(i):
    """Convert RGB label image `i` of the global `lbs` into a 2-D uint8 array of class ids.

    Pixels whose colour tuple is not a key of `code2id` get `failed_code`.
    Reads the globals `lbs`, `r`, `c`, `code2id` and `failed_code`.
    """
    res = np.zeros((r,c), 'uint8')
    for j in range(r):
        for k in range(c):
            # dict.get replaces a bare `except:`, which also swallowed IndexError,
            # TypeError and even KeyboardInterrupt as "unknown colour".
            res[j,k] = code2id.get(tuple(lbs[i,j,k]), failed_code)
    return res

In [None]:
from concurrent.futures import ProcessPoolExecutor

In [None]:
def conv_all_labels(max_workers=8):
    """Convert label images 0..n-1 to class-id maps in parallel worker processes.

    Returns an array of shape (n, r, c) stacked from conv_one_label results.
    NOTE(review): workers must be able to reach conv_one_label and the globals it
    reads, which holds with the 'fork' start method but not necessarily with 'spawn'.
    """
    # The context manager shuts the pool's processes down once all results are in.
    with ProcessPoolExecutor(max_workers) as ex:
        # ex.map yields lazily; np.stack requires a sequence (recent NumPy rejects generators).
        return np.stack(list(ex.map(conv_one_label, range(n))))

In [None]:
labels_int = conv_all_labels()

In [None]:
# save_array('labels_int.bc', labels_int)

In [None]:
np.count_nonzero(labels_int==failed_code)

In [None]:
l = []
for i in range(len(labels_int)):
    if np.count_nonzero(labels_int[i]==failed_code) > 0:
        l.append(i)

print (len(l))

In [None]:
l

In [None]:
labels_int = np.delete(labels_int, l, axis=0)
len(labels_int)

In [None]:
labels_int.shape

In [None]:
np.count_nonzero(labels_int==failed_code)

In [None]:
imgs = np.delete(imgs, l, axis=0)
len(imgs)

In [None]:
def color_label(a):
    """Render a 2-D array of class ids as an RGB uint8 image using the global `id2code`."""
    height, width = a.shape
    rgb = np.zeros((height, width, 3), 'uint8')
    # np.ndindex walks (row, col) in the same row-major order as nested loops.
    for row, col in np.ndindex(height, width):
        rgb[row, col] = id2code[a[row, col]]
    return rgb

In [None]:
# Training-mode batch of 4: random crop offsets and random left-right flips.
sg = segm_generator(imgs, lbs, bs=4, train=True)
b_img, b_label = next(sg)
plt.imshow(b_img[0]*0.3+0.4)

In [None]:
def dict_color_label(x):
    """Map each colour row of `x` (shape (pixels, channels)) to its id via the global `code2id`."""
    ids = []
    for pixel in x:
        ids.append(code2id[tuple(pixel)])
    return np.array(ids)

In [None]:
# Map the flattened RGB label back to ids, then to colours, to check the crop pairing.
temp = dict_color_label(b_label[0])
# reshape (not np.resize) raises on a size mismatch instead of silently tiling/truncating.
plt.imshow(color_label(temp.reshape(224,224)))

## Creating the train/test split

In [None]:
train_set = imgs[:46]
train_labels = labels_int[:46]

test_set = imgs[46:]
test_labels = labels_int[46:]

In [None]:
len(train_set), len(test_set), len(train_labels), len(test_labels)

In [None]:
plt.imshow(train_set[45]*0.3+0.4)

In [None]:
plt.imshow(color_label(train_labels[45]))

In [None]:
train_labels[0].shape

In [None]:
train_generator = segm_generator(train_set, train_labels, 3, train=True)
test_generator = segm_generator(test_set, test_labels, 3, train=False)

In [None]:
i,la = next(train_generator)

In [None]:
plt.imshow(i[0]*0.3+0.4)

In [None]:
# t = dict_color_label(la[0])
plt.imshow(color_label(np.resize(la[0], (224,224))))

In [None]:
la[0].shape

In [None]:
def relu(x):
    """Apply a ReLU activation layer to tensor x."""
    return Activation('relu')(x)

def dropout(x, p):
    """Apply Dropout(p) to x; a falsy p (None or 0) skips the layer entirely."""
    if not p:
        return x
    return Dropout(p)(x)

def bn(x):
    """Batch-normalise x over its last (channels) axis."""
    return BatchNormalization(axis=-1)(x)

def relu_bn(x):
    """Batch normalisation followed by ReLU."""
    normed = bn(x)
    return relu(normed)

def concat(xs):
    """Concatenate tensors along the last axis (legacy `merge(..., mode='concat')` API)."""
    return merge(xs, mode='concat', concat_axis=-1)

In [None]:
def conv(x, nf, sz, wd, p, stride=1):
    """sz x sz 'same' convolution with nf filters, L2 decay wd and stride, then dropout p."""
    layer = Convolution2D(nf, sz, sz, init='he_uniform', border_mode='same',
                          subsample=(stride, stride), W_regularizer=l2(wd))
    return dropout(layer(x), p)

def conv_relu_bn(x, nf, sz=3, wd=0, p=0, stride=1):
    """Pre-activation unit: BN -> ReLU -> conv (the name lists the steps in reverse)."""
    activated = relu_bn(x)
    return conv(activated, nf, sz, wd=wd, p=p, stride=stride)

In [None]:
def dense_block(n,x,growth_rate,p,wd):
    """Dense block: n conv units, each fed the concatenation of everything before it.

    Returns (x, added). x is the input concatenated with every new feature map;
    added is the list of the n new maps (growth_rate channels each).
    """
    added = []
    for _ in range(n):
        new_features = conv_relu_bn(x, growth_rate, p=p, wd=wd)
        added.append(new_features)
        x = concat([x, new_features])
    return x, added

In [None]:
def transition_dn(x, p, wd):
    """Halve spatial size with a strided 1x1 BN-ReLU-conv that keeps the channel count.

    (An alternative is a stride-1 1x1 conv followed by 2x2 MaxPooling2D.)
    """
    channels = x.get_shape().as_list()[-1]
    return conv_relu_bn(x, channels, sz=1, p=p, wd=wd, stride=2)

In [None]:
def down_path(x, nb_layers, growth_rate, p, wd):
    """Encoder: for each depth in nb_layers, a dense block then a downsampling transition.

    Returns (skips, added). skips holds each block's full output (for skip
    connections); added is the list of new feature maps from the LAST block.
    NOTE(review): the transition after the last block is built but its output is
    discarded (dead graph).
    """
    skips = []
    for n in nb_layers:
        x, added = dense_block(n, x, growth_rate, p, wd)
        skips.append(x)
        x = transition_dn(x, p=p, wd=wd)
    return skips, added

In [None]:
def transition_up(added, wd=0):
    """Double spatial size of concat(added) with a 3x3, stride-2 transposed convolution.

    The output keeps the concatenated channel count. (An alternative is
    UpSampling2D followed by a 2x2 conv.)
    """
    x = concat(added)
    _, rows, cols, channels = x.get_shape().as_list()
    # Deconvolution2D takes the full output shape (batch dim None) explicitly.
    out_shape = (None, rows*2, cols*2, channels)
    return Deconvolution2D(channels, 3, 3, out_shape, init='he_uniform',
                           border_mode='same', subsample=(2,2), W_regularizer=l2(wd))(x)

In [None]:
def up_path(added, skips, nb_layers, growth_rate, p, wd):
    """Decoder: per depth in nb_layers, upsample, concatenate the matching skip, dense block.

    Only the previous block's new feature maps (`added`) are upsampled at each step.
    Returns the output of the final dense block.
    """
    for step, depth in enumerate(nb_layers):
        upsampled = transition_up(added, wd)
        x = concat([upsampled, skips[step]])
        x, added = dense_block(depth, x, growth_rate, p, wd)
    return x

In [None]:
def reverse(a):
    """Return the items of sequence `a` as a new list, last item first."""
    return [item for item in reversed(a)]

In [None]:
def create_tiramisu(nb_classes, img_input, nb_dense_block=6,
    growth_rate=16, nb_filter=48, nb_layers_per_block=5, p=None, wd=0):
    """Build the FC-DenseNet ("Tiramisu") segmentation graph on top of img_input.

    Args:
        nb_classes: number of output classes (filters of the final 1x1 conv).
        img_input: Keras input tensor.
        nb_dense_block: number of dense blocks, used only when
            nb_layers_per_block is a single int.
        growth_rate: channels added by each conv unit inside a dense block.
        nb_filter: filters of the initial 3x3 conv.
        nb_layers_per_block: an int (same depth for every block) or a list/tuple
            of per-block depths. The last block acts as the bottleneck.
        p: dropout rate (None/0 disables dropout).
        wd: L2 weight decay.

    Returns:
        Softmax output of shape (batch, rows*cols, nb_classes).
    """
    if isinstance(nb_layers_per_block, (list, tuple)):
        nb_layers = list(nb_layers_per_block)
    else:
        nb_layers = [nb_layers_per_block] * nb_dense_block

    x = conv(img_input, nb_filter, 3, wd, 0)
    skips, added = down_path(x, nb_layers, growth_rate, p, wd)
    # The last block is the bottleneck: no skip connection and no matching up-block.
    x = up_path(added, reverse(skips[:-1]), reverse(nb_layers[:-1]), growth_rate, p, wd)

    x = conv(x, nb_classes, 1, wd, 0)
    # Flatten spatial dims to match the (batch, pixels, 1) targets from segm_generator.
    x = Reshape((-1, nb_classes))(x)
    return Activation('softmax')(x)

In [None]:
input_shape = (224,224,3)
img_input = Input(shape=input_shape)
x = create_tiramisu(32, img_input, nb_layers_per_block=[4,5,7,10,12,15], p=0.2, wd=1e-4)

In [None]:
model = Model(img_input, x)
model.compile(loss='sparse_categorical_crossentropy', 
              optimizer=keras.optimizers.RMSprop(1e-3), metrics=["accuracy"])

This model was trained for only 2 epochs because I ran out of GPU time. Train it for around 500 epochs to approach state-of-the-art results, and fine-tune it (for example with learning-rate annealing).

In [None]:
model.fit_generator(train_generator, len(train_set), 2, verbose=2,
                    validation_data=test_generator, nb_val_samples=len(test_set))

In [None]:
model.save_weights('../tiramisu_2_iterations.h5')

In [None]:
predictions = model.predict_generator(test_generator, len(test_set))

In [None]:
predictions = np.argmax(predictions, axis=-1)

In [None]:
predictions[0].shape

In [None]:
plt.imshow(color_label(np.resize(predictions[6], (224,224))))

In [None]:
plt.imshow(test_labels[3])

In [None]:
j, t_la = next(test_generator)

In [None]:
plt.imshow(j[0]*0.3+0.4)

In [None]:
try_image = np.array(Image.open('Seq05VD_f05100.png').resize((224,224), Image.NEAREST))
try_image = try_image/255.
try_image-=mean
try_image/=std
try_image.shape

In [None]:
try_preds = model.predict(np.expand_dims(try_image, 0), 1)

In [None]:
try_preds = np.argmax(try_preds, axis=-1)
try_preds.shape

In [None]:
plt.imshow(color_label(np.resize(try_preds[0], (224,224))))

In [None]:
try_img = np.array(Image.open('Seq05VD_f05100.png').resize(image_size, Image.NEAREST))
try_img = try_img/255.
try_img-=mean
try_img/=std

In [None]:
try_label = np.array(Image.open('Seq05VD_f05100_L.png').resize(image_size, Image.NEAREST))

In [None]:
try_gen = segm_generator(np.expand_dims(try_img, 0), np.expand_dims(try_img, 0), 1, train=False)

In [None]:
prd = model.predict_generator(try_gen, 1)

In [None]:
prd = np.argmax(prd, axis=-1)

In [None]:
prd.shape

In [None]:
plt.imshow(color_label(np.resize(prd[0], (224,224))))