In [1]:
%matplotlib inline
import importlib
import utils2; importlib.reload(utils2)
from utils2 import *

Using TensorFlow backend.


In [2]:
# limit_mem() comes from the star-imported utils2 module; presumably it configures
# the Keras/TensorFlow session (e.g. GPU memory growth) -- confirm in utils2.
limit_mem()

### Tiramisu (FC-DenseNet) semantic segmentation on CamVid

In [3]:
# Setup for training images
# NOTE(review): absolute, machine-specific dataset root (trailing slash required);
# point it at your local SegNet-Tutorial CamVid directory.
PATH = '/data/code/SegNet-Tutorial/CamVid/'
frames_path = PATH+'train/'
labels_path = PATH+'trainannot/'

# fnames: file name of each training image. glob returns files in a
# filesystem-dependent order, so sort to get a reproducible image order.
fnames = sorted(glob.glob(frames_path+'*.png'))
if not fnames:
    # Without this, np.stack below fails with an unhelpful "need at least one array".
    raise FileNotFoundError('No training images found matching ' + frames_path + '*.png')
# lnames: annotation file for each image (same basename, in the annot directory).
lnames = [labels_path+os.path.basename(fn) for fn in fnames]
# NOTE(review): presumably (width, height) in PIL order; not used in this cell.
img_sz = (480,360)

def open_image(fn):
    """Load the image file `fn` into a numpy array, closing the file afterwards."""
    with Image.open(fn) as im:
        return np.array(im)

imgs = np.stack([open_image(fn) for fn in fnames])
labels = np.stack([open_image(fn) for fn in lnames])

# Scale pixels to [0, 1], then apply a fixed shift (0.4) and scale (0.3).
# The test set must be normalised with exactly the same constants.
imgs = imgs/255.

n,r,c,ch = imgs.shape
imgs-=0.4
imgs/=0.3

In [4]:
# Setup for test images: same directory layout and preprocessing as the training
# set. PATH and open_image come from the training-setup cell above.
test_frames_path = PATH+'test/'
test_labels_path = PATH+'testannot/'

# test_fnames: file name of each test image, sorted for a reproducible order.
test_fnames = sorted(glob.glob(test_frames_path+'*.png'))
if not test_fnames:
    raise FileNotFoundError('No test images found matching ' + test_frames_path + '*.png')
# test_lnames: annotation file for each test image (same basename).
test_lnames = [test_labels_path+os.path.basename(fn) for fn in test_fnames]

test_imgs = np.stack([open_image(fn) for fn in test_fnames])
test_labels = np.stack([open_image(fn) for fn in test_lnames])

# Must match the training-set normalisation exactly.
test_imgs = test_imgs/255.

test_imgs-=0.4
test_imgs/=0.3

In [6]:
# Preprocessing: batch-index and crop generators used by fit_generator.

class BatchIndices(object):
    """Thread-safe, endlessly repeating iterator over batches of indices into range(n).

    Each next() returns a numpy array of at most `bs` indices. The last batch of a
    pass may be shorter than `bs`. Once all `n` indices have been handed out, a new
    index order is generated (re-permuted when `shuffle` is true).

    Raises ValueError if bs < 1 (which would otherwise yield empty batches forever).
    """
    def __init__(self, n, bs, shuffle=False):
        if bs < 1:
            raise ValueError('bs must be a positive integer, got {!r}'.format(bs))
        self.n,self.bs,self.shuffle = n,bs,shuffle
        # The cursor update in __next__ is guarded, presumably so several generator
        # worker threads can share one instance.
        self.lock = threading.Lock()
        self.reset()

    def reset(self):
        """Start a new pass: fresh (optionally shuffled) index order, cursor at 0."""
        self.idxs = (np.random.permutation(self.n)
                     if self.shuffle else np.arange(0, self.n))
        self.curr = 0

    def __iter__(self):
        # Make the object a proper iterator (usable with iter()/for), not only next().
        return self

    def __next__(self):
        with self.lock:
            if self.curr >= self.n: self.reset()
            ni = min(self.bs, self.n-self.curr)
            res = self.idxs[self.curr:self.curr+ni]
            self.curr += ni
            return res
        
# segmentation generator
class segm_generator(object):
    """Endless generator of (images, labels) crop batches for segmentation training.

    x:      4-D array; axis 0 indexes samples, axes 1 and 2 (rows, cols) are cropped.
    y:      labels indexed the same way on axes 0-2, either 3-D or 4-D.
    bs:     batch size.
    out_sz: (rows, cols) of each crop; must not exceed the image size.
    train:  if true, batches are shuffled, crops are random and half are flipped
            horizontally; otherwise a fixed crop is used with no flipping.

    next() returns (xs, ys): xs stacks the cropped images and ys is reshaped to
    (batch, rows*cols, ych), one row per pixel.

    Raises ValueError if out_sz is larger than the images.
    """
    def __init__(self, x, y, bs=64, out_sz=(224,224), train=True):
        self.x, self.y, self.bs, self.train = x,y,bs,train
        self.n, self.ri, self.ci, _ = x.shape
        self.ro, self.co = out_sz
        # Fail fast: an oversized crop would otherwise surface later as a
        # random.randint ValueError (train) or as short/mis-sized crops (eval).
        if self.ro > self.ri or self.co > self.ci:
            raise ValueError('out_sz {} exceeds image size {}'.format(
                tuple(out_sz), (self.ri, self.ci)))
        self.idx_gen = BatchIndices(self.n, bs, train)
        self.ych = self.y.shape[-1] if len(y.shape)==4 else 1

    def get_slice(self, i,o):
        """Slice of length o within an axis of length i.

        Train: random start in [0, i-o]. Eval: fixed start i-o, i.e. the crop is
        flush with the end (bottom/right) of the axis.
        """
        start = random.randint(0, i-o) if self.train else (i-o)
        return slice(start, start+o)

    def get_item(self, idx):
        """Crop sample `idx` of x and y identically; in training, flip 50% left-right."""
        slice_r = self.get_slice(self.ri, self.ro)
        slice_c = self.get_slice(self.ci, self.co)
        x = self.x[idx, slice_r, slice_c]
        y = self.y[idx, slice_r, slice_c]
        if self.train and (random.random()>0.5):
            # Axis 1 of the crop is the column axis, so this is a horizontal flip.
            y = y[:,::-1]
            x = x[:,::-1]
        return x, y

    def __iter__(self):
        # Make the generator a proper iterator, not only usable via next().
        return self

    def __next__(self):
        idxs = next(self.idx_gen)
        items = (self.get_item(idx) for idx in idxs)
        xs,ys = zip(*items)
        return np.stack(xs), np.stack(ys).reshape(len(ys), -1, self.ych)

In [7]:
# Label names and their RGB colour codes; list position = label id in the
# annotation images (color_label looks colours up by that id).
label_names = ['Sky', 'Building', 'Pole', 
               'Road', 'Pavement', 'Tree', 
               'SignSymbol', 'Fence', 'Car', 
               'Pedestrian', 'Bicyclist', 'Unlabelled']

label_codes = [(128,128,128),
               (128,0,0),
               (192,192,128),
               (128,64,128),
               (60,40,222),
               (128,128,0),
               (192,128,128),
               (64,64,128),
               (64,0,128),
               (64,64,0),
               (0,128,192),
               # Unlabelled is black, as in the SegNet CamVid colour map. It previously
               # duplicated Bicyclist's colour, making the two classes indistinguishable.
               (0,0,0)]

In [8]:
# convert id label to color label
def color_label(a, codes=None):
    """Map an integer label-id image to an RGB image using a colour table.

    a:     integer array of label ids (typically (rows, cols)); every id must be a
           valid index into `codes`, otherwise IndexError is raised.
    codes: sequence of (R, G, B) tuples indexed by label id; defaults to the
           module-level `label_codes`.
    Returns a uint8 array of shape a.shape + (3,).
    """
    if codes is None:
        codes = label_codes
    # One vectorised table lookup instead of a Python loop over every pixel.
    palette = np.asarray(codes, dtype='uint8')
    return palette[np.asarray(a)]

In [11]:
# Short aliases used by the model-training cells below.
trn = imgs            # normalised training images
trn_labels = labels   # training label-id images

test = test_imgs      # normalised test images
# test_labels (test label-id images) is already defined by the test-data cell.

# number of training and test images
rnd_trn = len(trn_labels)
rnd_test = len(test_labels)

In [12]:
## The Tiramisu network
def relu(x):
    """Apply a ReLU activation layer to tensor `x`."""
    return Activation('relu')(x)

def dropout(x, p):
    """Apply Dropout(p) to `x`; a falsy `p` (0 or None) adds no layer."""
    if not p:
        return x
    return Dropout(p)(x)

def bn(x):
    """Batch-norm hook, currently an identity (batch normalisation is disabled).

    An earlier version used BatchNormalization(mode=2, axis=-1)(x).
    """
    return x

def relu_bn(x):
    """bn (currently identity) followed by ReLU."""
    return relu(bn(x))

def concat(xs):
    """Concatenate the tensors in `xs` along the last axis (channels, here)."""
    return merge(xs, mode='concat', concat_axis=-1)

def conv(x, nf, sz, wd, p, stride=1):
    """sz x sz 'same'-padded convolution with nf filters, then optional dropout.

    x: input tensor; nf: number of filters; sz: square kernel size;
    wd: L2 penalty factor on the kernel; p: dropout rate (falsy = no dropout);
    stride: stride used along both spatial axes.
    """
    layer = Convolution2D(nf, sz, sz,
                          init='he_uniform',
                          border_mode='same',
                          subsample=(stride, stride),
                          W_regularizer=l2(wd))
    return dropout(layer(x), p)

def conv_relu_bn(x, nf, sz=3, wd=0, p=0, stride=1):
    """relu_bn(x) followed by conv(...) with the given settings."""
    activated = relu_bn(x)
    return conv(activated, nf, sz, wd=wd, p=p, stride=stride)

def dense_block(n,x,growth_rate,p,wd):
    """Dense block of n conv_relu_bn layers with growth_rate filters each.

    Each layer receives the concatenation of the block input and all earlier
    layer outputs. Returns (x, added): the final concatenated tensor and the list
    of the n new feature maps (the block input is not included).
    """
    new_maps = []
    for _ in range(n):
        layer_out = conv_relu_bn(x, growth_rate, p=p, wd=wd)
        new_maps.append(layer_out)
        x = concat([x, layer_out])
    return x, new_maps

def transition_dn(x, p, wd):
    """Transition-down: 1x1 conv keeping the channel count, with stride 2.

    The strided conv replaces the alternative of conv_relu_bn(x, ch, sz=1)
    followed by MaxPooling2D(strides=(2, 2)).
    """
    n_channels = x.get_shape().as_list()[-1]
    return conv_relu_bn(x, n_channels, sz=1, p=p, wd=wd, stride=2)

def down_path(x, nb_layers, growth_rate, p, wd):
    """Encoder: one dense block per entry of nb_layers, with transition-downs between.

    Returns (skips, added): `skips` holds the output of every dense block (the
    last one being the bottleneck); `added` holds the new feature maps of the
    final (bottleneck) block.

    Raises ValueError if nb_layers is empty.
    """
    if not nb_layers:
        raise ValueError('nb_layers must contain at least one dense block size')
    skips = []
    last = len(nb_layers) - 1
    for i,n in enumerate(nb_layers):
        x,added = dense_block(n,x,growth_rate,p,wd)
        skips.append(x)
        # No transition after the bottleneck: its output is never used and would
        # only add a dangling strided-conv layer (and its variables) to the graph.
        if i < last:
            x = transition_dn(x, p=p, wd=wd)
    return skips, added

def transition_up(added, wd=0):
    """Transition-up: concatenate `added`, then a stride-2 3x3 deconvolution.

    Keeps the channel count and outputs twice the rows and columns. An alternative
    is UpSampling2D() followed by conv(x, ch, 2, wd, 0).
    """
    merged = concat(added)
    _, rows, cols, n_ch = merged.get_shape().as_list()
    out_shape = (None, rows * 2, cols * 2, n_ch)
    deconv = Deconvolution2D(n_ch, 3, 3, out_shape, init='he_uniform',
                             border_mode='same', subsample=(2, 2),
                             W_regularizer=l2(wd))
    return deconv(merged)

def up_path(added, skips, nb_layers, growth_rate, p, wd):
    """Decoder: per level, upsample `added`, concatenate skips[level], run a dense block.

    Returns the output tensor of the last dense block.
    """
    for level, n in enumerate(nb_layers):
        upsampled = transition_up(added, wd)
        merged = concat([upsampled, skips[level]])
        x, added = dense_block(n, merged, growth_rate, p, wd)
    return x

## Build the tiramisu model
def reverse(a):
    """Return a new list holding the items of sequence `a` in reverse order."""
    return [item for item in reversed(a)]

def create_tiramisu(nb_classes, img_input, nb_dense_block=6, 
    growth_rate=16, nb_filter=48, nb_layers_per_block=5, p=None, wd=0):
    """Build the Tiramisu (FC-DenseNet) graph on top of `img_input`.

    nb_classes:          number of output classes.
    img_input:           Keras input tensor.
    nb_dense_block:      number of encoder dense blocks; ignored when
                         nb_layers_per_block is a list/tuple (its length is used).
    growth_rate:         filters added by each conv layer inside a dense block.
    nb_filter:           filters of the initial 3x3 convolution.
    nb_layers_per_block: int (same depth for every block) or list/tuple of
                         per-block depths in encoder order; the last is the bottleneck.
    p:                   dropout rate (None/0 disables dropout).
    wd:                  L2 weight-decay factor.
    Returns a softmax tensor of shape (batch, rows*cols, nb_classes).
    """
    if isinstance(nb_layers_per_block, (list, tuple)):
        nb_layers = list(nb_layers_per_block)
    else:
        nb_layers = [nb_layers_per_block] * nb_dense_block

    x = conv(img_input, nb_filter, 3, wd, 0)
    skips, added = down_path(x, nb_layers, growth_rate, p, wd)
    # The decoder mirrors the encoder without the bottleneck (last) block.
    x = up_path(added, reverse(skips[:-1]), reverse(nb_layers[:-1]),
                growth_rate, p, wd)

    # 1x1 conv to per-class scores, flatten the pixels, softmax over classes.
    x = conv(x, nb_classes, 1, wd, 0)
    x = Reshape((-1, nb_classes))(x)
    return Activation('softmax')(x)

## Train the network
# limit_mem() (from utils2) is called again, presumably to start from a fresh
# session when this cell is re-run -- confirm in utils2.
limit_mem()
input_shape = (224,224,3)
img_input = Input(shape=input_shape)

# 12 output classes = len(label_names).
x = create_tiramisu(12, img_input, nb_layers_per_block=[4,5,7,10,12,15], p=0.2, wd=1e-4)

model = Model(img_input, x)
gen = segm_generator(trn, trn_labels, 3, train=True)
gen_test = segm_generator(test, test_labels, 3, train=False)
# Compile directly with the decaying optimizer (lr=1e-3, decay=1-0.99995=5e-5).
# The previous version compiled with a plain RMSprop and then replaced
# model.optimizer, which relies on Keras building the training function lazily.
model.compile(loss='sparse_categorical_crossentropy',
              optimizer=keras.optimizers.RMSprop(1e-3, decay=1-0.99995),
              metrics=["accuracy"])

In [None]:
# start the training process
# Keras 1 API: samples_per_epoch and nb_val_samples count samples, not batches.
model.fit_generator(gen,
                    samples_per_epoch=rnd_trn,
                    nb_epoch=100,
                    verbose=2,
                    validation_data=gen_test,
                    nb_val_samples=rnd_test)

Epoch 1/100


In [None]:
# Save network weights. Create the results/ directory first if it is missing;
# otherwise the HDF5 write fails.
import os
weights_dir = PATH+'results/'
os.makedirs(weights_dir, exist_ok=True)
model.save_weights(weights_dir+'tiramisu_net.h5')

In [None]:
# To modify

In [None]:
# Larger crops: same architecture and regularisation (p=0.2, wd=1e-4) as the
# 224x224 model, but trained on 352x480 (rows x cols) crops with a lower
# learning rate (1e-4).
lrg_sz = (352,480)
gen = segm_generator(trn, trn_labels, 2, out_sz=lrg_sz, train=True)
gen_test = segm_generator(test, test_labels, 2, out_sz=lrg_sz, train=False)

lrg_shape = lrg_sz+(3,)
lrg_input = Input(shape=lrg_shape)

x = create_tiramisu(12, lrg_input, nb_layers_per_block=[4,5,7,10,12,15], p=0.2, wd=1e-4)
lrg_model = Model(lrg_input, x)
lrg_model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=keras.optimizers.RMSprop(1e-4), metrics=["accuracy"])

In [None]:
# load previous tiramisu
# The weighted layers are convolutions/deconvolutions, whose kernel shapes do not
# depend on the input size, so weights from the 224x224 model fit this model too.
lrg_model.load_weights(
    '{}results/tiramisu_net.h5'.format(PATH))

In [None]:
# Keras 1 API: samples_per_epoch and nb_val_samples count samples, not batches.
lrg_model.fit_generator(gen,
                        samples_per_epoch=rnd_trn,
                        nb_epoch=100,
                        verbose=2,
                        validation_data=gen_test,
                        nb_val_samples=rnd_test)