In [48]:
import tensorflow as tf
import os
import numpy as np
import math
import matplotlib.pyplot as plt
import sys
from string import ascii_lowercase
%matplotlib inline

In [50]:
def start_block(x):
    """ResNet stem: optional 1x1 projection to 3 channels, 7x7/2 Conv+BN+ReLU, 2x2/2 max-pool.

    Args:
        x: 4-D input tensor; its last axis is treated as the channel axis.

    Returns:
        A 64-channel tensor, spatially reduced by the stride-2 convolution and
        the stride-2 max-pool.
    """
    n_channels = x.get_shape().as_list()[-1]
    # Project non-3-channel inputs down to 3 channels so 'conv1' always sees 3 channels.
    if n_channels != 3:
        x = conv_bn_relu(x, ['conv0', 'bn_conv0'], kernel_size=1, filters=3)
    stem = conv_bn_relu(x, ['conv1', 'bn_conv1'], kernel_size=7, stride=2, filters=64)
    return tf.layers.max_pooling2d(stem, pool_size=2, strides=2)

def conv_bn(x, names, kernel_size, filters, stride=1, dilation_rate=1, training=False):
    """2-D convolution with 'SAME' padding followed by batch normalization (no activation).

    Args:
        x: input tensor.
        names: sequence whose first two items are the names of the conv layer
            and of the batch-norm layer, in that order.
        kernel_size, filters, stride, dilation_rate: forwarded to tf.layers.conv2d.
        training: accepted for signature compatibility but NOT used (see note below).

    Returns:
        The batch-normalized convolution output.
    """
    conv_name, bn_name = names[0], names[1]
    conv_out = tf.layers.conv2d(x, filters=filters, kernel_size=kernel_size,
                                strides=stride, dilation_rate=dilation_rate,
                                padding='SAME', name=conv_name)
    # NOTE(review): BN mode comes from the module-level `is_train` placeholder,
    # not from `training`, so `is_train` must exist before the graph is built.
    # Callers pass `training` as a Python bool (it also selects strides and
    # dilation rates), and the model is built with training=False. Switching
    # this line to `training=training` alone would therefore pin BN to
    # inference mode; change both together if you untangle them.
    return tf.layers.batch_normalization(conv_out, training=is_train, name=bn_name)
    
def conv_bn_relu(x, names, kernel_size, filters, stride=1, dilation_rate=1, training=False):
    """conv_bn followed by a ReLU; all arguments are forwarded unchanged to conv_bn."""
    return tf.nn.relu(conv_bn(x, names, kernel_size, filters,
                              stride=stride,
                              dilation_rate=dilation_rate,
                              training=training))
    
def conv_block(x, name, filters, stride=1, dilation_rate=1, training=False):
    """ResNet bottleneck unit: 1x1 -> 3x3 (dilated) -> 1x1 with 4*filters outputs, plus shortcut.

    Layers are named 'res<name>_branch2a' / 'bn<name>_branch2a', etc.
    Presumably this is so they match the variable names restored from the
    pretrained ResNet-50 weights.

    Args:
        x: input tensor (channels last).
        name: unit suffix such as '2a'.
        filters: bottleneck width; the unit outputs 4 * filters channels.
        stride: applied on the first 1x1 conv and on the projection shortcut.
        dilation_rate: dilation of the 3x3 conv.
        training: forwarded to every conv layer.

    Returns:
        relu(shortcut + residual).
    """
    in_channels = x.get_shape().as_list()[-1]
    out_channels = 4 * filters

    def _names(branch):
        # Conv name and BN name for one branch of this unit.
        return ['res{}_branch{}'.format(name, branch), 'bn{}_branch{}'.format(name, branch)]

    residual = conv_bn_relu(x, _names('2a'), kernel_size=1, stride=stride,
                            filters=filters, training=training)
    residual = conv_bn_relu(residual, _names('2b'), kernel_size=3, filters=filters,
                            dilation_rate=dilation_rate, training=training)
    residual = conv_bn(residual, _names('2c'), kernel_size=1, filters=out_channels,
                       training=training)

    # Project the shortcut only when the spatial size or channel count changes.
    needs_projection = stride > 1 or in_channels != out_channels
    shortcut = x
    if needs_projection:
        shortcut = conv_bn(x, _names('1'), kernel_size=1, filters=out_channels,
                           stride=stride, training=training)
    return tf.nn.relu(shortcut + residual)

def res_block(x, num, filters, n_blocks, stride=1, dilation_rate=1, training=False):
    """Stack `n_blocks` bottleneck units named '<num>a', '<num>b', ...

    Only the first unit uses `stride`; later units use stride 1. Every unit
    uses the same `dilation_rate`.
    """
    unit_names = ['{}{}'.format(num, ascii_lowercase[k]) for k in range(n_blocks)]
    out = x
    for k, unit_name in enumerate(unit_names):
        unit_stride = stride if k == 0 else 1
        out = conv_block(out, unit_name, filters, stride=unit_stride,
                         dilation_rate=dilation_rate, training=training)
    return out

In [51]:
def deeplab_resnet50(x, training=False):
    """Dilated ResNet-50 backbone extended with extra cascaded blocks 6-8.

    `training` is a Python bool that selects the architecture:
      * True: block 4 uses stride 2 and the base dilation is 1.
      * False: block 4 uses stride 1 and the base dilation is 2, so the
        output feature map is twice as large per spatial dimension.

    Returns:
        (conv2, features): the output of block 2 (used as the decoder's
        low-level input) and the output of block 8.
    """
    base_rate = 2 if not training else 1
    x = start_block(x)
    conv2 = res_block(x, 2, filters=64, n_blocks=3, training=training)
    x = res_block(conv2, 3, filters=128, n_blocks=4, stride=2, training=training)
    x = res_block(x, 4, filters=256, n_blocks=6, stride=1 if not training else 2,
                  dilation_rate=base_rate, training=training)
    # Blocks 5-8 keep the resolution; the dilation doubles at each block.
    for block_num, multiplier in zip((5, 6, 7, 8), (2, 4, 8, 16)):
        x = res_block(x, block_num, filters=512, n_blocks=3,
                      dilation_rate=multiplier * base_rate, training=training)
    return conv2, x

In [52]:
def aspp(x, training=False):
    """Atrous Spatial Pyramid Pooling head (DeepLab v3 style).

    Branches:
      * a 1x1 conv;
      * three 3x3 atrous convs;
      * an image-level branch (global average pool -> 1x1 conv -> bilinear
        upsample).
    Each branch is Conv+BN+ReLU with 256 filters. The branches are
    concatenated and projected back to 256 channels with another 1x1
    Conv+BN+ReLU.

    Args:
        x: backbone feature map (NHWC).
        training: Python bool; should match the value given to the backbone.
            With False the atrous rates are doubled (6/12/18 -> 12/24/36),
            presumably to compensate for the backbone keeping a finer output
            stride in that mode.

    Returns:
        A 256-channel feature map with the same spatial size as `x`.
    """
    rate = 1 if training else 2
    # ReLU after every branch and after the projection: with plain conv_bn the
    # whole head is an affine function of the backbone features.
    branches = [conv_bn_relu(x, ['aspp_conv1', 'aspp_bn1'], 1, 256, training=training)]
    for idx, base_rate in enumerate((6, 12, 18), start=2):
        branches.append(conv_bn_relu(x, ['aspp_conv%d' % idx, 'aspp_bn%d' % idx], 3, 256,
                                     dilation_rate=base_rate * rate, training=training))

    global_features = tf.reduce_mean(x, axis=[1, 2], keep_dims=True)
    global_features = conv_bn_relu(global_features, ['global_conv1', 'global_bn1'], 1, 256,
                                   training=training)
    global_features = tf.image.resize_bilinear(images=global_features,
                                               size=tf.shape(branches[0])[1:3])

    concat = tf.concat(branches + [global_features], axis=-1)
    return conv_bn_relu(concat, ['aspp_out_conv1', 'aspp_out_bn1'], 1, 256, training=training)
    

In [53]:
def deeplab_encoder(x, training=False):
    """DeepLab encoder: dilated ResNet-50 backbone followed by ASPP.

    Args:
        x: input image batch.
        training: Python bool selecting the backbone/ASPP architecture.

    Returns:
        (low_level, aspp_out): block-2 backbone features and the ASPP output.
    """
    low_level, res_out = deeplab_resnet50(x, training)
    # ASPP chooses its atrous rates from `training`, so it must receive the
    # same value as the backbone.
    aspp_out = aspp(res_out, training=training)
    return low_level, aspp_out

In [54]:
def deeplab_decoder(low_level, aspp_out, training=False):
    """DeepLab v3+ style decoder.

    Steps:
      1. Reduce the low-level features to 48 channels.
      2. Upsample the ASPP output to the low-level spatial size.
      3. Concatenate the two and apply two 3x3 Conv+BN+ReLU layers.
      4. Upsample the result by 4.

    Args:
        low_level: backbone block-2 features (NHWC).
        aspp_out: ASPP output (NHWC); presumably spatially smaller than `low_level`.
        training: forwarded to the conv layers.

    Returns:
        256-channel decoder features at 4x the low-level spatial size.
    """
    low_level = conv_bn_relu(low_level, ['dec_res_conv1', 'dec_res_bn1'], 1, 48,
                             training=training)
    # Resize to the low-level map's exact size rather than by a mode-dependent
    # factor (2 or 4). The result is the same when the sizes divide evenly,
    # and the concat stays valid when they do not.
    aspp_up = tf.image.resize_bilinear(images=aspp_out, size=tf.shape(low_level)[1:3])
    x = tf.concat([low_level, aspp_up], axis=-1)
    # ReLU after each fusion conv; without it the decoder is purely affine.
    x = conv_bn_relu(x, ['dec_conv1', 'dec_bn1'], 3, 256, training=training)
    x = conv_bn_relu(x, ['dec_conv2', 'dec_bn2'], 3, 256, training=training)
    return tf.image.resize_bilinear(images=x, size=tf.shape(x)[1:3] * 4)

In [55]:
def deeplab(x, n_classes, training=False):
    """Full segmentation model: encoder, decoder, then a 1x1 Conv+BN producing per-pixel logits.

    Args:
        x: input image batch.
        n_classes: number of output channels (classes).
        training: Python bool forwarded to the encoder, decoder and output layer.

    Returns:
        Unnormalized class scores with `n_classes` channels.
    """
    low_feats, context_feats = deeplab_encoder(x, training)
    decoded = deeplab_decoder(low_feats, context_feats, training)
    return conv_bn(decoded, ['out_conv1', 'out_bn1'],
                   kernel_size=1, filters=n_classes, training=training)

In [56]:
h = w = 256
downsample = 2
filespath = 'VOCdevkit/VOC2012/ImageSets/Segmentation'
imgs_path = 'VOCdevkit/VOC2012/JPEGImages/{}.jpg'
segs_path = 'VOCdevkit/VOC2012/SegmentationMap/{}.png'


def _read_split(split_name):
    """Return the non-empty image ids listed in `<filespath>/<split_name>.txt`.

    The file is closed on exit. The last id is kept even when the file has no
    trailing newline (the former `split('\\n')[:-1]` silently dropped it).
    """
    with open(os.path.join(filespath, split_name + '.txt')) as split_file:
        return [line.strip() for line in split_file if line.strip()]


train = _read_split('train')
val = _read_split('val')
img_files, mask_files = [[path.format(f) for f in train] for path in [imgs_path, segs_path]]
valid_img_files, valid_mask_files = [[path.format(f) for f in val] for path in [imgs_path, segs_path]]
n_epochs = 5
batch_size = 10
num_ex = len(img_files)
num_valid = len(valid_img_files)
batches_per_epoch = math.ceil(num_ex / batch_size)
num_valid_batches = math.ceil(num_valid / batch_size)

In [59]:
def read_img(img_file, img_type, height=None, width=None, downsample=None):
    """Load one image file into a float32 tensor.

    Args:
        img_file: path (string or scalar string tensor).
        img_type: 'jpeg' or 'png'; selects the decoder.
        height, width: if both are given, center-crop / zero-pad to this size.
        downsample: if given, keep every `downsample`-th row and column
            (applied before the crop/pad).

    Returns:
        The decoded image cast to float32 (values are not rescaled).
    """
    assert img_type in ['jpeg', 'png']
    decode = {'jpeg': tf.image.decode_jpeg, 'png': tf.image.decode_png}[img_type]
    img = decode(tf.read_file(img_file))
    if downsample is not None:
        img = img[::downsample, ::downsample]
    if height is not None and width is not None:
        img = tf.image.resize_image_with_crop_or_pad(img, height, width)
    return tf.cast(img, tf.float32)
    
def read_imgs(files, dtypes, height=None, width=None, downsample=None):
    """Call read_img on each (file, type) pair and return the results as a list.

    Pairs are formed with zip, so surplus items in the longer sequence are ignored.
    """
    imgs = []
    for path, kind in zip(files, dtypes):
        imgs.append(read_img(path, kind, height, width, downsample))
    return imgs

def make_dataset(img_files, mask_files,
                 batch_size, n_epochs=None,
                 height=None, width=None,
                 downsample=None, shuffle=True):
    """Build a tf.data pipeline yielding batches of (image, mask) float32 tensors.

    Pipeline stages:
      1. Zip the two path lists.
      2. Optionally shuffle with a buffer the size of the full list.
      3. Decode the images (jpeg image, png mask).
      4. Batch.
      5. Repeat `n_epochs` times (indefinitely when None).

    Batching happens before repeating, so each pass ends with a smaller final
    batch when the file count is not a multiple of `batch_size`.
    """
    img_ds = tf.data.Dataset.from_tensor_slices(img_files)
    mask_ds = tf.data.Dataset.from_tensor_slices(mask_files)
    pairs = tf.data.Dataset.zip((img_ds, mask_ds))
    if shuffle:
        pairs = pairs.shuffle(len(img_files))

    def _load(img_path, mask_path):
        return read_imgs([img_path, mask_path], ['jpeg', 'png'], height, width, downsample)

    return pairs.map(_load).batch(batch_size).repeat(n_epochs)

In [60]:
tf.reset_default_graph()

train_dataset = make_dataset(img_files, mask_files, batch_size, n_epochs, h, w, downsample)
valid_dataset = make_dataset(valid_img_files, valid_mask_files, batch_size, n_epochs, h, w, downsample)

# Feedable iterator: the same graph reads from either dataset, selected per
# sess.run call by feeding that dataset iterator's string handle.
dataset_handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(dataset_handle,
                                               train_dataset.output_types,
                                               train_dataset.output_shapes)
train_iterator = train_dataset.make_one_shot_iterator()
valid_iterator = valid_dataset.make_one_shot_iterator()

x, y = iterator.get_next()

x.set_shape([batch_size, h, w, 3])
y.set_shape([batch_size, h, w, 4])

x = x / 255
# NOTE(review): the class id is taken from channel 0 of the 4-channel mask PNG;
# confirm this matches the SegmentationMap encoding.
y = y[..., 0]

# Fed True during training and False during validation; conv_bn reads this
# placeholder for batch-norm mode.
is_train = tf.placeholder(name='is_train', dtype=tf.bool, shape=None)

n_classes = 21
class_weights = np.ones(n_classes) / n_classes
with tf.variable_scope('model'):
    logits = deeplab(x, n_classes=n_classes, training=False)
preds = tf.argmax(logits, axis=-1, output_type=tf.int32)

y_one_hot = tf.one_hot(tf.cast(y, tf.int32), depth=n_classes, axis=-1)
preds_one_hot = tf.one_hot(preds, depth=n_classes, axis=-1)

# Every pixel currently counts. To ignore label 0, use instead:
#   mask = tf.expand_dims(1 - tf.cast(tf.equal(y, 0), tf.float32), axis=-1)
mask = tf.expand_dims(tf.ones_like(y), axis=-1)

y_one_hot_masked = y_one_hot * mask
preds_one_hot_masked = preds_one_hot * mask

# Per-class accuracy over the batch, and overall pixel accuracy.
class_counts = tf.reduce_sum(y_one_hot, axis=[0, 1, 2])
equal = tf.cast(tf.equal(tf.cast(y, tf.int32), preds), tf.float32)
num_right = tf.reduce_sum(tf.expand_dims(equal, axis=-1) * y_one_hot_masked, axis=[0, 1, 2])
class_acc = num_right / class_counts

macc = tf.reduce_mean(equal)

# Per-image, per-class IoU averaged over the batch.
# NOTE(review): a class absent from both label and prediction in an image
# scores (0 + 1e-10) / (0 + 1e-10) = 1, which inflates miou.
intersection = tf.reduce_sum(y_one_hot_masked * preds_one_hot_masked, axis=[1, 2])
union = (tf.reduce_sum(y_one_hot_masked, axis=[1, 2])
         + tf.reduce_sum(preds_one_hot_masked, axis=[1, 2]) - intersection)
iou = tf.reduce_mean((intersection + 1e-10) / (union + 1e-10), axis=0)
miou = tf.reduce_mean(iou)

# Class-weighted cross-entropy against the GROUND-TRUTH labels. Using the
# model's own one-hot predictions as the target would reward confidence in
# its argmax and never look at `y`.
logprobs = tf.nn.log_softmax(logits)
losses = -tf.reduce_mean(logprobs * y_one_hot_masked, axis=(0, 1, 2))
loss = tf.reduce_sum(losses * class_weights)

# Run batch-norm moving-average updates together with each optimizer step.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    opt = tf.train.MomentumOptimizer(0.007, momentum=0.9)
    train_step = opt.minimize(loss)



In [64]:
train_loss, valid_loss = [], []
train_iou, valid_iou = [], []
train_acc, valid_acc = [], []

sess = tf.Session()

# The npz is used only for its key list. Graph variables whose names (minus
# the 'model/' scope prefix) appear among those keys are restored from the
# ResNet-50 checkpoint; every other variable is freshly initialized.
weights = np.load('resnet50.npz')
pretrained_names = weights.keys()
assigned_vars = [var for var in tf.global_variables()
                 if var.name.replace('model/', '') in pretrained_names]
unassigned_vars = set(tf.global_variables()) - set(assigned_vars)

deeplab_saver = tf.train.Saver(var_list=assigned_vars)
deeplab_saver.restore(sess, 'pho_seg_preproc_minus_convex_test_diff_new/notebooks/checkpoints/resnet50_ckpt')
sess.run(tf.variables_initializer(unassigned_vars))



INFO:tensorflow:Restoring parameters from pho_seg_preproc_minus_convex_test_diff_new/notebooks/checkpoints/resnet50_ckpt


In [None]:
print('Training')

# Iterator.string_handle() creates a new graph op on every call, so evaluate
# each handle once up front. The one-shot iterators keep their position
# between sess.run calls, so reusing the handles continues each stream.
training_handle = sess.run(train_iterator.string_handle())
validation_handle = sess.run(valid_iterator.string_handle())

for i in range(1, n_epochs + 1):
    for n in range(1, batches_per_epoch + 1):
        _, t_loss, t_iou, t_acc = sess.run([train_step, loss, miou, macc],
                                           {dataset_handle: training_handle, is_train: True})

        train_loss.append(t_loss)
        train_iou.append(t_iou)
        train_acc.append(t_acc)
        # Running means over the n batches seen so far in this epoch.
        mean_loss = np.mean(train_loss[-n:])
        mean_iou = np.mean(train_iou[-n:])
        mean_acc = np.mean(train_acc[-n:])
        sys.stdout.write('\rTraining: epoch {}/{}, batch {}/{}, loss:{:.4f}, iou:{:.4f}, acc:{:.4f}'.format(
            i, n_epochs, n, batches_per_epoch, mean_loss, mean_iou, mean_acc))

    sys.stdout.write('\r\n')

    v_loss = []
    v_iou = []
    v_acc = []
    # Pre-initialised so the summary line below is defined even with zero validation batches.
    out_str = 'Validation: batch {}/{}'.format(0, num_valid_batches)
    for n in range(1, num_valid_batches + 1):
        v_l, v_i, v_a = sess.run([loss, miou, macc], {dataset_handle: validation_handle, is_train: False})
        v_loss.append(v_l)
        v_iou.append(v_i)
        v_acc.append(v_a)

        out_str = 'Validation: batch {}/{}'.format(n, num_valid_batches)
        sys.stdout.write('\r{}'.format(out_str))

    mean_loss = np.mean(v_loss)
    mean_iou = np.mean(v_iou)
    mean_acc = np.mean(v_acc)

    valid_loss.append(mean_loss)
    valid_iou.append(mean_iou)
    valid_acc.append(mean_acc)
    sys.stdout.write('\r{}, loss: {:.4f}, iou: {:.4f}, acc:{:.4f}'.format(out_str, mean_loss, mean_iou, mean_acc))
    sys.stdout.write('\r\n')

Training
