In [1]:
import tensorflow as tf
import numpy as np
import math
import matplotlib.pyplot as plt
import sys
%matplotlib inline

In [2]:
def conv_bn_relu(x, filters, is_train, relu=True):
    """Apply a 3x3 convolution, batch normalisation and an optional ReLU.

    Args:
        x: input tensor passed to ``tf.layers.conv2d``.
        filters: number of output channels of the convolution.
        is_train: value forwarded as ``training`` to batch normalisation.
        relu: when true (the default) the result goes through ``tf.nn.relu``.

    Returns:
        The transformed tensor.
    """
    # The convolution has no bias term; it is followed directly by batch norm.
    convolved = tf.layers.conv2d(
        x, filters=filters, kernel_size=3, padding='SAME', use_bias=False)
    normalised = tf.layers.batch_normalization(convolved, training=is_train)
    return tf.nn.relu(normalised) if relu else normalised

In [3]:
def pool(x):
    """2x2, stride-2 max pooling that also reports where each maximum came from.

    Args:
        x: input tensor passed to ``tf.nn.max_pool_with_argmax``.

    Returns:
        A tuple ``(pooled, argmax, input_shape)``: the pooled tensor, the argmax
        indices cast to int32, and the dynamic shape of ``x`` before pooling
        (needed later to undo the pooling).
    """
    input_shape = tf.shape(x)
    window = [1, 2, 2, 1]
    pooled, argmax = tf.nn.max_pool_with_argmax(
        x, ksize=window, strides=window, padding='VALID')
    return pooled, tf.cast(argmax, tf.int32), input_shape

In [4]:
def unpool(x, pool_inds, shape):
    """Scatter the values of ``x`` back to the positions recorded in ``pool_inds``.

    Every element of ``x`` is written into a zero tensor of shape ``shape`` at
    the flat position given by the matching element of ``pool_inds``.

    Args:
        x: tensor to un-pool; flattened to supply the scattered values.
        pool_inds: integer indices, one per element of ``x``.
            NOTE(review): used here as offsets *within* one example (the batch
            index is prepended separately) -- confirm this matches the index
            layout of ``tf.nn.max_pool_with_argmax`` in the TF version in use.
        shape: target shape (batch first) of the un-pooled tensor.

    Returns:
        Tensor of shape ``shape``.
    """
    flat_positions = tf.reshape(pool_inds, [-1, 1])

    # Build a column holding, for every scattered value, the example it belongs to.
    n_examples = tf.shape(x)[0]
    example_ids = tf.expand_dims(
        tf.range(start=0, limit=n_examples, delta=1), axis=-1)
    values_per_example = tf.size(pool_inds[0])
    example_ids = tf.reshape(
        tf.tile(example_ids, [1, values_per_example]), [-1, 1])

    # Each scatter index is the pair (example, flat position inside the example).
    scatter_index = tf.concat([example_ids, flat_positions], axis=1)
    flat_shape = [shape[0], tf.reduce_prod(shape[1:])]
    scattered = tf.scatter_nd(
        indices=scatter_index, updates=tf.reshape(x, [-1]), shape=flat_shape)
    return tf.reshape(scattered, shape)

In [5]:
def segnet(x, n_classes, is_train):
    """Build the SegNet encoder/decoder and return per-pixel class logits.

    The encoder is five stages of conv+BN+ReLU layers (13 convolutions in
    total), each stage ending in a max-pool whose argmax indices are kept.
    The decoder walks the stages in reverse: it un-pools with the stored
    indices and applies conv+BN layers *without* ReLU. A final conv+BN layer
    (also without ReLU) produces ``n_classes`` channels.

    Args:
        x: input image tensor.
        n_classes: number of output channels of the final layer.
        is_train: forwarded to batch normalisation in every layer.

    Returns:
        The logits tensor.
    """
    # (filters, number of conv layers) for each encoder stage.
    encoder_stages = [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)]
    # Filters of the conv layers that follow each un-pooling step, deepest first.
    decoder_stages = [
        (512, 512, 512),
        (512, 512, 256),
        (256, 256, 128),
        (128, 64),
        (64,),
    ]

    net = x
    pooling_memory = []  # stack of (argmax indices, pre-pool shape)
    for n_filters, depth in encoder_stages:
        for _ in range(depth):
            net = conv_bn_relu(net, n_filters, is_train)
        net, argmax, pre_pool_shape = pool(net)
        pooling_memory.append((argmax, pre_pool_shape))

    for stage_filters in decoder_stages:
        # The most recent pooling is undone first.
        argmax, pre_pool_shape = pooling_memory.pop()
        net = unpool(net, argmax, pre_pool_shape)
        for n_filters in stage_filters:
            net = conv_bn_relu(net, n_filters, is_train, relu=False)

    return conv_bn_relu(net, n_classes, is_train, relu=False)

In [6]:
def _read_file_list(list_path):
    """Return the non-blank lines of a CamVid list file.

    The leading '/' of every '/SegNet' occurrence is stripped so the paths are
    relative to the working directory. The file is closed as soon as it has
    been read, and blank lines are skipped instead of assuming the file ends
    with exactly one trailing newline.
    """
    with open(list_path) as list_file:
        content = list_file.read().replace('/SegNet','SegNet')
    return [line for line in content.split('\n') if line.strip()]


train_files = _read_file_list('SegNet/CamVid/train.txt')
valid_files = _read_file_list('SegNet/CamVid/val.txt')
# Each line holds an image path and a mask path separated by whitespace.
# Only the first 360 training entries are used.
img_files, mask_files = zip(*[i.split() for i in train_files[:360]])
valid_img_files, valid_mask_files = zip(*[i.split() for i in valid_files])

In [24]:
# Spatial stride used when loading batches for training/validation: batch_gen keeps
# every `downsample`-th pixel along both image axes.
downsample = 2

In [25]:
# Image size after down-sampling, used for the placeholder shapes.
# Derived from `downsample` (rather than a hard-coded 2) and computed as an int:
# ceil(n / ds) is exactly the length produced by the `[::ds]` slicing in batch_gen.
height = math.ceil(360 / downsample)
width = math.ceil(480 / downsample)

In [26]:
def batch_gen(batch_size, img_files, mask_files, ds=1, imread=None):
    """Yield shuffled mini-batches of images and label masks for one epoch.

    Args:
        batch_size: number of examples per batch; the last batch may be smaller.
        img_files: sequence of image file paths.
        mask_files: sequence of mask file paths, aligned with ``img_files``.
        ds: stride applied to the two axes after the batch axis (``[::ds, ::ds]``).
        imread: callable mapping a path to an array. Defaults to ``plt.imread``.

    Yields:
        ``(batch_number, imgs, masks)``. ``batch_number`` starts at 1, ``imgs``
        is float32, and ``masks`` is float32 holding ``255 * mask`` rounded to
        the nearest integer.

    Raises:
        ValueError: if ``img_files`` and ``mask_files`` differ in length.
    """
    if imread is None:
        imread = plt.imread
    num_ex = len(img_files)
    if len(mask_files) != num_ex:
        raise ValueError('img_files and mask_files must have the same length')

    # One permutation is applied to both lists so image/mask pairs stay aligned.
    shuffle = np.random.permutation(num_ex)
    batches_per_epoch = math.ceil(num_ex/batch_size)
    img_files = np.array(img_files)[shuffle]
    mask_files = np.array(mask_files)[shuffle]
    for i in range(batches_per_epoch):
        slc = slice(i*batch_size, (i+1)*batch_size)
        imgs = np.array([imread(j) for j in img_files[slc]])[:,::ds,::ds]
        masks = np.array([imread(j) for j in mask_files[slc]])[:,::ds,::ds]
        # The masks are rescaled by 255 to recover label values (presumably the
        # reader returns them scaled to [0, 1] -- confirm for the file format in
        # use). Rounding guards against k/255*255 landing just off k, which
        # would break exact `==` comparisons and integer casts downstream.
        labels = np.rint(255*masks).astype(np.float32)
        yield (i+1, imgs.astype(np.float32), labels)

In [27]:
# Count how many pixels carry each of the 12 label values over the training masks.
counts = np.zeros(12)
# One example per batch; `ds` is left at its default of 1, so masks are not strided.
tg = batch_gen(1,img_files, mask_files)
# Row vector [0..11]: comparing an (N, 1) column of labels against it broadcasts
# to an (N, 12) table of matches.
eq = np.reshape(np.arange(12), [1,12])
for i,im,mk in tg:
    # Column sums give the number of pixels per label value in this mask.
    # NOTE(review): relies on `mk` holding exact integer values (float `==` int).
    counts += np.sum(np.reshape(mk,[-1,1])==eq, axis=0)

In [28]:
# Per-class weights: the median class count divided by each class count, so
# classes with fewer pixels get larger weights.
# NOTE(review): a class with a zero count would get an infinite weight.
# NOTE(review): the graph-building cell reassigns `weights` to uniform values,
# so these numbers are not what the loss uses -- confirm that is intended.
weights = np.median(counts)/counts
weights

array([  0.24624816,   0.17903213,   4.28582449,   0.13155787,
         0.94207353,   0.42706214,   3.51309235,   3.7142853 ,
         0.69747728,   6.72559599,  14.52094855,   1.06551678])

In [29]:
# Build the graph: input placeholders, SegNet logits, evaluation metrics,
# the class-weighted loss and the training op.
tf.reset_default_graph()
x = tf.placeholder(name='x',dtype=tf.float32,shape=[None,height,width,3])
y = tf.placeholder(name='y',dtype=tf.int32,shape=[None,height,width])
is_train = tf.placeholder(name='is_train',dtype=tf.bool,shape=None)

n_classes = 12
# NOTE(review): uniform weights (1/n_classes each); this replaces any `weights`
# value computed in an earlier cell -- confirm that is intended.
weights = np.ones(n_classes)/n_classes
logits = segnet(x,n_classes,is_train)
# Predicted class per pixel = channel with the largest logit.
preds = tf.argmax(logits, axis=-1, output_type=tf.int32)

y_one_hot = tf.one_hot(tf.cast(y, tf.int32), depth=n_classes, axis=-1)
preds_one_hot = tf.one_hot(preds, depth=n_classes, axis=-1)


# Per-class accuracy over the whole batch: correctly labelled pixels of a class
# divided by the number of ground-truth pixels of that class.
class_counts = tf.reduce_sum(y_one_hot, axis=[0,1,2])
equal = tf.cast(tf.equal(y, preds), tf.float32)
num_right = tf.reduce_sum(tf.expand_dims(equal,axis=-1)*y_one_hot, axis=[0,1,2])
# NOTE(review): a class with no ground-truth pixels in the batch gives 0/0 here.
class_acc = num_right/class_counts

# Fraction of all pixels whose prediction equals the label.
macc = tf.reduce_mean(equal)

# IoU is computed per example and per class (sums over the two spatial axes),
# then averaged over the batch and finally over classes.
intersection = tf.reduce_sum(y_one_hot*preds_one_hot, axis=[1,2])
union = tf.reduce_sum(y_one_hot, axis=[1,2]) + tf.reduce_sum(preds_one_hot, axis=[1,2]) - intersection
# The 1e-10 terms avoid 0/0; a class absent from both label and prediction
# therefore scores 1 for that example.
iou = tf.reduce_mean((intersection + 1e-10)/(union + 1e-10), axis=0)
miou = tf.reduce_mean(iou)


# Cross-entropy split by class: mean over batch and spatial axes of
# -log p(class) masked by the one-hot label, then a weighted sum over classes.
logprobs = tf.nn.log_softmax(logits)
losses = -tf.reduce_mean(logprobs*y_one_hot, axis=(0,1,2))
loss = tf.reduce_sum(losses*weights)

# Ops in the UPDATE_OPS collection are made a dependency of the training step
# (presumably the batch-normalisation moving-statistics updates).
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    opt = tf.train.MomentumOptimizer(1e-1,momentum=0.9)
    train_step = opt.minimize(loss)


In [30]:
# Training-run configuration.
n_epochs = 70
batch_size = 10
num_ex = len(img_files)
num_valid = len(valid_img_files)
# Batch counts shown in the progress output; the last batch of an epoch may be partial.
batches_per_epoch = math.ceil(num_ex/batch_size)
num_valid_batches = math.ceil(num_valid/batch_size)

In [31]:
# Training history. The train_* lists get one entry per batch, stored as
# (batch mean * batch size) so that summing an epoch's entries and dividing by
# the number of examples seen gives the running epoch mean. The valid_* lists
# get one per-epoch mean each.
train_loss = []
valid_loss = []
train_iou = []
valid_iou = []
train_acc = []
valid_acc = []

sess = tf.Session()
sess.run(tf.global_variables_initializer())

print('Training')

for i in range(1,n_epochs+1):
    # ---- training pass ----
    train_gen = batch_gen(batch_size, img_files, mask_files, downsample)

    num_so_far = 0
    for n, imgs, masks in train_gen:
        _, t_loss, t_iou, t_acc = sess.run([train_step, loss, miou, macc],
            {x: imgs, y: masks, is_train:True})

        train_loss.append(t_loss*len(imgs))
        train_iou.append(t_iou*len(imgs))
        train_acc.append(t_acc*len(imgs))
        num_so_far += len(imgs)
        # `n` is the 1-based batch number, so the last n entries are exactly
        # the batches of the current epoch.
        mean_loss = np.sum(train_loss[-n:])/num_so_far
        mean_iou = np.sum(train_iou[-n:])/num_so_far
        mean_acc = np.sum(train_acc[-n:])/num_so_far
        sys.stdout.write('\rTraining: epoch {}/{}, batch {}/{}, loss:{:.4f}, iou:{:.4f}, acc:{:.4f}'.format(
            i,n_epochs,n,batches_per_epoch, mean_loss, mean_iou, mean_acc))

    # ---- validation pass ----
    valid_gen = batch_gen(batch_size, valid_img_files, valid_mask_files, downsample)

    sys.stdout.write('\r\n')

    num_so_far = 0

    v_loss = 0
    v_iou = 0
    v_acc = 0
    # Defined up front so the summary line below still works when the
    # validation generator yields no batches.
    out_str = 'Validation: batch {}/{}'.format(0,num_valid_batches)
    for n, imgs, masks in valid_gen:
        v_l, v_i, v_a = sess.run([loss, miou, macc], {x: imgs, y: masks, is_train:False})
        num_so_far += len(imgs)
        v_loss += v_l*len(imgs)
        v_iou += v_i*len(imgs)
        v_acc += v_a*len(imgs)

        # Only the two values the template uses are passed; the original
        # supplied two extra arguments that str.format silently ignored.
        out_str = 'Validation: batch {}/{}'.format(n,num_valid_batches)
        sys.stdout.write('\r{}'.format(out_str))

    # max(..., 1) avoids dividing by zero for an empty validation set.
    mean_loss = v_loss/max(num_so_far, 1)
    mean_iou = v_iou/max(num_so_far, 1)
    mean_acc = v_acc/max(num_so_far, 1)

    valid_loss.append(mean_loss)
    valid_iou.append(mean_iou)
    valid_acc.append(mean_acc)
    sys.stdout.write('\r{}, loss: {:.4f}, iou: {:.4f}, acc:{:.4f}'.format(out_str, mean_loss, mean_iou, mean_acc))
    sys.stdout.write('\r\n')

Training
Training: epoch 1/70, batch 36/36, loss:0.1921, iou:0.0763, acc:0.2739
Validation: batch 11/11, loss: 0.1883, iou: 0.0240, acc:0.2881
Training: epoch 2/70, batch 36/36, loss:0.1123, iou:0.1905, acc:0.6695
Validation: batch 11/11, loss: 0.1709, iou: 0.0252, acc:0.2915
Training: epoch 3/70, batch 36/36, loss:0.0928, iou:0.2666, acc:0.6932
Validation: batch 11/11, loss: 0.1662, iou: 0.0317, acc:0.3106
Training: epoch 4/70, batch 36/36, loss:0.0811, iou:0.2870, acc:0.7216
Validation: batch 11/11, loss: 0.1612, iou: 0.0371, acc:0.3288
Training: epoch 5/70, batch 36/36, loss:0.0713, iou:0.3303, acc:0.7703
Validation: batch 11/11, loss: 0.1588, iou: 0.0426, acc:0.3488
Training: epoch 6/70, batch 36/36, loss:0.0652, iou:0.3477, acc:0.7884
Validation: batch 11/11, loss: 0.1600, iou: 0.0534, acc:0.3658
Training: epoch 7/70, batch 36/36, loss:0.0599, iou:0.3578, acc:0.8007
Validation: batch 11/11, loss: 0.1532, iou: 0.0668, acc:0.4025
Training: epoch 8/70, batch 36/36, loss:0.0552, iou:0

G    C    mIOU G     C   mIOU 
84.0 54.6 46.3 96.1 83.9 73.3