# Read data

In [None]:
import numpy as np

import glob
import skimage.io as io
from PIL import Image

import os.path
import sys

import tensorflow as tf
from tensorflow.contrib import slim


sys.path.append("models/research/slim/")

## Get image and mask pair file names

In [None]:
%matplotlib inline

# Training images are the .tif files whose *file name* does not contain "mask";
# checking only the basename keeps a directory name from filtering everything out,
# and sorting makes the order of records written later deterministic.
origin_images = sorted(img for img in glob.glob("train_subset/*.tif")
                       if 'mask' not in os.path.basename(img))

def fimg_to_fmask(img_path):
    """Return the mask file path that pairs with the image at ``img_path``.

    The mask lives in the same directory and has ``_mask`` inserted before the
    file extension, e.g. ``train_subset/1_1.tif`` -> ``train_subset/1_1_mask.tif``.
    Only the final extension is touched, so a ".tif" appearing elsewhere in the
    file name is left alone.
    """
    dirname, basename = os.path.split(img_path)
    root, ext = os.path.splitext(basename)
    return os.path.join(dirname, root + "_mask" + ext)

# Pair every training image path with the path of its segmentation mask.
paired_images = list(zip(origin_images, map(fimg_to_fmask, origin_images)))

# Sanity-check one sample: load it, report its Python type and array shape, show it.
sample_path = 'train_subset/1_1.tif'
img = io.imread(sample_path)
for info in (type(img), img.shape):
    print(info)
io.imshow(img)

## Construct the TFRecords binary data file

Code adapted from this blog post: http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/

In [None]:
def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature``."""
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)

def _int64_feature(value):
    """Wrap a single integer value in a ``tf.train.Feature``."""
    payload = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=payload)

tfrecords_filename = 'medical_image_segmentation.tfrecords'

# In-memory copies of every (image, mask) pair, used by the next cell to verify
# that the TFRecords round trip is lossless.
original_images = []

# The context managers guarantee the record writer and the image files are
# closed even if one of the images fails to load.
with tf.python_io.TFRecordWriter(tfrecords_filename) as writer:
    for img_path, segmentation_path in paired_images:

        with Image.open(img_path) as img_file, Image.open(segmentation_path) as seg_file:
            img = np.array(img_file)
            seg = np.array(seg_file)

        height, width = img.shape[:2]

        original_images.append((img, seg))

        # Raw pixel bytes; the reader must know the dtype (uint8 is assumed
        # downstream) and rebuild the shape from the stored height/width.
        # tobytes() is the non-deprecated spelling of ndarray.tostring().
        example = tf.train.Example(features=tf.train.Features(feature={
            'height': _int64_feature(height),
            'width': _int64_feature(width),
            'image_raw': _bytes_feature(img.tobytes()),
            'mask_raw': _bytes_feature(seg.tobytes())}))

        writer.write(example.SerializeToString())

## Check images
Check that the images reconstructed from the TFRecords file match the original images.

In [None]:
def _decode_record(string_record):
    """Parse one serialized ``tf.train.Example`` back into an (image, mask) pair.

    Both arrays are rebuilt as 2D uint8 arrays of the stored (height, width);
    the uint8 dtype mirrors what the pipeline assumes everywhere else — TODO
    confirm if the source TIFFs ever become 16-bit or multi-channel.
    """
    example = tf.train.Example()
    example.ParseFromString(string_record)
    feature = example.features.feature

    height = int(feature['height'].int64_list.value[0])
    width = int(feature['width'].int64_list.value[0])
    img_string = feature['image_raw'].bytes_list.value[0]
    seg_string = feature['mask_raw'].bytes_list.value[0]

    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    reconstructed_img = np.frombuffer(img_string, dtype=np.uint8).reshape((height, width))
    reconstructed_seg = np.frombuffer(seg_string, dtype=np.uint8).reshape((height, width))
    return reconstructed_img, reconstructed_seg


reconstructed_images = [_decode_record(string_record)
                        for string_record in tf.python_io.tf_record_iterator(path=tfrecords_filename)]

# zip() below stops at the shorter list, so report a count mismatch explicitly
# instead of letting missing records go unnoticed.
if len(reconstructed_images) != len(original_images):
    print('Record count mismatch: %d original vs %d reconstructed'
          % (len(original_images), len(reconstructed_images)))

# check if the reconstructed images match the original images
for original_pair, reconstructed_pair in zip(original_images, reconstructed_images):

    img_pair_to_compare, seg_pair_to_compare = zip(original_pair, reconstructed_pair)
    # The round trip should be lossless, so compare exactly (array_equal also
    # returns False on a shape mismatch instead of raising like allclose).
    print(np.array_equal(*img_pair_to_compare))
    print(np.array_equal(*seg_pair_to_compare))

## Preprocess data

### Load data from the TFRecords file

In [None]:
def load_batch(tfrecords_filename, capacity, img_height, img_width, batch_size=32, num_epochs=10, is_training=False):
    """Build a queue-based input pipeline that yields shuffled (image, mask) batches.

    Args:
      tfrecords_filename: path of the TFRecords file written by the cells above.
      capacity: maximum number of examples held in the shuffle queue.
      img_height, img_width: static spatial size of every stored image/mask;
        records of any other size will fail at batching time.
      batch_size: number of examples per batch.
      num_epochs: how many passes over the file before the input raises
        OutOfRangeError.
      is_training: currently unused — batches are always shuffled. Kept for
        interface compatibility.

    Returns:
      (images, segmentations): uint8 tensors of shape
      [batch_size, img_height, img_width, 1].
    """
    # With num_epochs set, string_input_producer tracks epochs in a local
    # variable, so callers must run tf.local_variables_initializer().
    filename_queue = tf.train.string_input_producer([tfrecords_filename], num_epochs=num_epochs)

    reader = tf.TFRecordReader()

    _, serialized_example = reader.read(filename_queue)

    features = tf.parse_single_example(
      serialized_example,
      features={
        'height': tf.FixedLenFeature([], tf.int64),
        'width': tf.FixedLenFeature([], tf.int64),
        'image_raw': tf.FixedLenFeature([], tf.string),
        'mask_raw': tf.FixedLenFeature([], tf.string)
        })

    image = tf.decode_raw(features['image_raw'], tf.uint8)
    segmentation = tf.decode_raw(features['mask_raw'], tf.uint8)

    height = tf.cast(features['height'], tf.int32)
    width = tf.cast(features['width'], tf.int32)

    # Image and mask share the same [height, width] layout; add a trailing
    # channel axis so both become [height, width, 1].
    hw_shape = tf.stack([height, width])
    image = tf.expand_dims(tf.reshape(image, hw_shape), -1)
    segmentation = tf.expand_dims(tf.reshape(segmentation, hw_shape), -1)

    # A dequeue of `batch_size` needs `batch_size + min_after_dequeue` queued
    # elements, so keep that within `capacity` for small queues.
    min_after_dequeue = max(0, min(10, capacity - batch_size))

    images, segmentations = tf.train.shuffle_batch([image, segmentation],
                                                 batch_size=batch_size,
                                                 capacity=capacity,
                                                 num_threads=2,
                                                 min_after_dequeue=min_after_dequeue,
                                                 shapes=[[img_height, img_width, 1], [img_height, img_width, 1]])

    return images, segmentations

### Demo: loading data from the TFRecords file

In [None]:
%matplotlib inline

TFRECORDS_FILENAME = 'medical_image_segmentation.tfrecords'
IMAGE_HEIGHT = 420
IMAGE_WIDTH = 580
CAPACITY = 599

images, segmentations = load_batch(TFRECORDS_FILENAME, capacity=CAPACITY, img_height=IMAGE_HEIGHT, img_width=IMAGE_WIDTH,
                                   batch_size=32, num_epochs=10, is_training=True)

# Local variables must be initialized too: the epoch counter lives there.
init_op = tf.group(tf.global_variables_initializer(),
                   tf.local_variables_initializer())

with tf.Session() as sess:

    sess.run(init_op)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # try/finally so the queue-runner threads are always stopped and joined,
    # even if fetching or plotting a batch raises.
    try:
        for _ in range(3):

            imgs, segs = sess.run([images, segmentations])

            # Replicate the single grey channel three times so io.imshow
            # receives an RGB-shaped array.
            imgs = np.concatenate([imgs, imgs, imgs], axis=3)
            segs = np.concatenate([segs, segs, segs], axis=3)

            print(imgs[0, :, :, :].shape)

            print('current batch')

            # Show up to three samples; guard against batches smaller than 3.
            for j in range(min(3, imgs.shape[0])):

                io.imshow(imgs[j, :, :, :])
                io.show()

                io.imshow(segs[j, :, :, :])
                io.show()
    finally:
        coord.request_stop()
        coord.join(threads)

### downsampling and upsampling

In [None]:
def get_kernel_size(factors):
    """
    Return the transposed-convolution kernel size for each upsampling factor.

    An even factor f maps to 2*f taps and an odd factor to 2*f - 1, so kernels
    for odd factors have an odd size and stay centred on a pixel.
    """
    kernel_sizes = []
    for factor in factors:
        kernel_sizes.append(2 * factor - factor % 2)
    return kernel_sizes


def upsample_filt(sizes):
    """
    Make a 2D bilinear kernel suitable for upsampling, given the kernel's (h, w) size.

    Each axis gets a triangular profile 1 - |i - center| / factor, and the
    kernel is the outer product of the row and column profiles. Only the first
    two entries of ``sizes`` are used.
    """
    factors = [(size + 1) // 2 for size in sizes]
    # Odd sizes are centred on a pixel; even sizes between the two middle pixels.
    center = [factor - 1 if size % 2 == 1 else factor - 0.5
              for size, factor in zip(sizes, factors)]
    # Open row/column index grids; broadcasting them yields the (h, w) kernel.
    # Built once rather than on every iteration of a per-axis loop.
    og = np.ogrid[:sizes[0], :sizes[1]]
    return (1 - abs(og[0] - center[0]) / factors[0]) * (1 - abs(og[1] - center[1]) / factors[1])


def bilinear_upsample_weights(factors, number_of_classes):
    """
    Build a fixed transposed-convolution filter that bilinearly upsamples each
    class channel independently.

    Returns a float32 array of shape (kh, kw, number_of_classes, number_of_classes),
    where (kh, kw) are the first two sizes from get_kernel_size(factors). Only the
    diagonal slices [:, :, c, c] hold the bilinear kernel, so channels never mix.
    """
    kernel_shape = get_kernel_size(factors)
    bilinear_kernel = upsample_filt(kernel_shape)

    weights = np.zeros((kernel_shape[0], kernel_shape[1], number_of_classes, number_of_classes),
                       dtype=np.float32)

    # Write the 2D kernel into every (c, c) slice with one broadcast assignment.
    diagonal = np.arange(number_of_classes)
    weights[:, :, diagonal, diagonal] = bilinear_kernel[:, :, np.newaxis]

    return weights

# CNN model architecture

- Fully Convolutional Network: https://www.cv-foundation.org/openaccess/content_cvpr_2015/html/Long_Fully_Convolutional_Networks_2015_CVPR_paper.html

- One FCN-8 model implementation: https://github.com/warmspringwinds/tf-image-segmentation

- Tensorflow slim VGG16 model with FCN feature: https://github.com/tensorflow/models/blob/master/research/slim/nets/vgg.py

In [None]:
def vgg_arg_scope(weight_decay=0.0005):
  """Return a slim arg_scope holding the default conv2d settings for `vgg`.

  Every conv2d built under it uses ReLU, L2 weight decay of `weight_decay`,
  zero-initialized biases and SAME padding unless overridden at the call site.
  """
  conv_defaults = dict(
      activation_fn=tf.nn.relu,
      weights_regularizer=slim.l2_regularizer(weight_decay),
      biases_initializer=tf.zeros_initializer(),
      padding='SAME')
  with slim.arg_scope([slim.conv2d], **conv_defaults) as arg_sc:
    return arg_sc

In [None]:
def vgg(inputs, num_classes=2, is_training=True, dropout_keep_prob=0.5, scope='vgg', fc_conv_padding='SAME'):
  """VGG16-style encoder sized for 420x580 single-channel inputs, used as the FCN backbone.

  Downsamples by 60x in height (2*2*3*5) and 20x in width (2*2*5*1), so a
  [420, 580] input yields a [7, 29] map of `num_classes` logits. Shape comments
  below assume the SAME conv padding set by `vgg_arg_scope`.

  Args:
    inputs: float image batch, presumably [batch, 420, 580, 1] — confirm for other data.
    num_classes: number of output logit channels.
    is_training: enables dropout.
    dropout_keep_prob: keep probability for the dropout layers.
    scope: variable scope name.
    fc_conv_padding: padding of the 7x29 'fc5' conv; 'SAME' keeps the [7, 29]
      spatial size that `fcn` relies on.

  Returns:
    (net, end_points): the final logits and a dict of every conv/pool output
    keyed by its scope name (e.g. '<scope>/pool3').
  """
  with tf.variable_scope(scope, 'vgg', [inputs]) as sc:
    end_points_collection = sc.name + '_end_points'
    with slim.arg_scope([slim.conv2d, slim.max_pool2d], outputs_collections=end_points_collection):
      # [420, 580, 1]
      net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1') # [420, 580, 64]
      net = slim.max_pool2d(net, [2, 2], scope='pool1') # [210, 290, 64]
      net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2') # [210, 290, 128]
      net = slim.max_pool2d(net, [2, 2], scope='pool2') # [105, 145, 128]
      net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3') # [105, 145, 256]
      # slim.max_pool2d defaults to stride=2; these non-square pools must set the
      # stride to the kernel size to downsample by exactly (3, 5) and (5, 1),
      # which is what the skip connections in `fcn` are sized for.
      net = slim.max_pool2d(net, [3, 5], stride=[3, 5], scope='pool3') # [35, 29, 256]
      net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4') # [35, 29, 512]
      net = slim.max_pool2d(net, [5, 1], stride=[5, 1], scope='pool4') # [7, 29, 512]
      # Use conv2d instead of fully_connected layers.
      net = slim.conv2d(net, 1024, [7, 29], padding=fc_conv_padding, scope='fc5') # [7, 29, 1024]
      net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
                         scope='dropout5')
      net = slim.conv2d(net, 1024, [1, 1], scope='fc6') # [7, 29, 1024]
      net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
                         scope='dropout6')
      net = slim.conv2d(net, num_classes, [1, 1],
                        activation_fn=None,
                        normalizer_fn=None,
                        scope='fc7') # [7, 29, num_classes]
      # Convert end_points_collection into an end_point dict.
      end_points = slim.utils.convert_collection_to_dict(end_points_collection)
      return net, end_points

In [None]:
def _upsample_logits(logits, kernel, factors):
    """Upsample NHWC `logits` by integer (row, col) `factors` with a fixed transposed-conv kernel.

    The output spatial size is exactly (H * factors[0], W * factors[1]); batch and
    channel sizes are read from `logits` at run time.
    """
    shape = tf.shape(logits)
    output_shape = tf.stack([shape[0],
                             shape[1] * factors[0],
                             shape[2] * factors[1],
                             shape[3]])
    return tf.nn.conv2d_transpose(logits,
                                  kernel,
                                  output_shape=output_shape,
                                  strides=[1, factors[0], factors[1], 1])


def _score_layer(features, number_of_classes, scope):
    """1x1 conv turning a skip-connection feature map into per-class logits.

    Weights start at zero (and biases default to zero), so a freshly added skip
    path initially contributes nothing to the fused logits.
    """
    return slim.conv2d(features,
                       number_of_classes,
                       [1, 1],
                       activation_fn=None,
                       normalizer_fn=None,
                       weights_initializer=tf.zeros_initializer(),
                       scope=scope)


def fcn(image_batch_tensor, number_of_classes=2, is_training=False):
    """Build an FCN-8-style segmentation network on top of `vgg`.

    The coarse `vgg` logits are upsampled in three bilinear (fixed-weight) steps
    of (5, 1), (3, 5) and (4, 4), fusing pool3 and pool2 skip logits after the
    first two steps. The total upsampling is (60, 20), mirroring `vgg`'s
    downsampling, so the output logits match the input's spatial size.

    Args:
      image_batch_tensor: image batch, cast to float32 here.
      number_of_classes: number of output logit channels.
      is_training: forwarded to `vgg` (controls dropout).

    Returns:
      (logits, variables_mapping): logits of shape
      [batch, H, W, number_of_classes], and a dict from checkpoint-style names
      ('fcn/...' without the ':0' suffix) to every variable under the 'fcn'
      scope except the 'pool2_fc' ones — presumably for restoring from a model
      trained without the pool2 skip; confirm against the checkpoint in use.
    """
    image_batch_float = tf.to_float(image_batch_tensor)

    upsample_filter_factor_5_1_tensor = tf.constant(
        bilinear_upsample_weights(factors=[5, 1], number_of_classes=number_of_classes))
    upsample_filter_factor_3_5_tensor = tf.constant(
        bilinear_upsample_weights(factors=[3, 5], number_of_classes=number_of_classes))
    upsample_filter_factor_4_4_tensor = tf.constant(
        bilinear_upsample_weights(factors=[4, 4], number_of_classes=number_of_classes))

    with tf.variable_scope("fcn") as fcn_scope:
        with slim.arg_scope(vgg_arg_scope()):
            last_layer_logits, end_points = vgg(image_batch_float,
                                                num_classes=number_of_classes,
                                                is_training=is_training,
                                                fc_conv_padding='SAME')

            # Step 1: upsample the final logits by (5, 1) to pool3's resolution and fuse.
            last_layer_upsampled_logits = _upsample_logits(last_layer_logits,
                                                           upsample_filter_factor_5_1_tensor,
                                                           [5, 1])
            pool3_logits = _score_layer(end_points['fcn/vgg/pool3'], number_of_classes, 'pool3_fc')
            fused_pool3_logits = pool3_logits + last_layer_upsampled_logits

            # Step 2: upsample by (3, 5) to pool2's resolution and fuse again.
            fused_pool3_upsampled_logits = _upsample_logits(fused_pool3_logits,
                                                            upsample_filter_factor_3_5_tensor,
                                                            [3, 5])
            pool2_logits = _score_layer(end_points['fcn/vgg/pool2'], number_of_classes, 'pool2_fc')
            fused_pool2_logits = pool2_logits + fused_pool3_upsampled_logits

            # Step 3: upsample by (4, 4) back to the input resolution.
            output_logits = _upsample_logits(fused_pool2_logits,
                                             upsample_filter_factor_4_4_tensor,
                                             [4, 4])

            fcn_5_1_variables_mapping = {}
            for variable in slim.get_variables(fcn_scope):
                if 'pool2_fc' in variable.name:
                    continue
                # Strip the scope prefix and the ':0' output suffix, then re-prefix 'fcn/'.
                original_fcn_5_1_checkpoint_string = 'fcn/' + variable.name[len(fcn_scope.original_name_scope):-2]
                fcn_5_1_variables_mapping[original_fcn_5_1_checkpoint_string] = variable

    return output_logits, fcn_5_1_variables_mapping

# Model training

In [None]:
TFRECORDS_FILENAME = 'medical_image_segmentation.tfrecords'
IMAGE_HEIGHT = 420
IMAGE_WIDTH = 580
CAPACITY = 21
BATCH_SIZE = 2
# Integer division: string_input_producer needs an integer epoch count, and
# `/` would give a float (11.5) under Python 3.
NUM_EPOCHS = CAPACITY // BATCH_SIZE + 1
NUMBER_OF_CLASSES = 2
LEARNING_RATE = 0.1

train_log_dir = 'fcn_model_checkpoints/'
if not tf.gfile.Exists(train_log_dir):
    tf.gfile.MakeDirs(train_log_dir)
print('Will save model to %s' % train_log_dir)

with tf.Graph().as_default():

    images, segmentations = load_batch(TFRECORDS_FILENAME, capacity=CAPACITY,
                                       img_height=IMAGE_HEIGHT, img_width=IMAGE_WIDTH,
                                       batch_size=BATCH_SIZE, num_epochs=NUM_EPOCHS, is_training=True)

    logits, _ = fcn(image_batch_tensor=images,
                    number_of_classes=NUMBER_OF_CLASSES,
                    is_training=True)

    # One row of class scores per pixel.
    flat_logits = tf.reshape(tensor=logits, shape=(-1, NUMBER_OF_CLASSES))
    # The mask has a single uint8 channel per pixel. Map it to one integer class
    # id per pixel (0 = background, any non-zero value = foreground) so the
    # labels line up row-for-row with flat_logits. Reshaping the raw mask to
    # (-1, NUMBER_OF_CLASSES) would instead pair up neighbouring pixels.
    # NOTE: this binarization only fits NUMBER_OF_CLASSES == 2.
    flat_labels = tf.reshape(tf.cast(segmentations > 0, tf.int32), shape=(-1,))
    cross_entropies = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=flat_labels,
                                                                     logits=flat_logits)
    cross_entropy_sum = tf.reduce_sum(cross_entropies)

    tf.summary.scalar('losses/Total_Loss', cross_entropy_sum)

    optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
    train_op = slim.learning.create_train_op(cross_entropy_sum, optimizer)

    # Run the training:
    final_loss = slim.learning.train(
        train_op,
        logdir=train_log_dir,
        number_of_steps=1,
        save_summaries_secs=2,
        save_interval_secs=2)

    # %f: the loss is a float, %d would truncate it.
    print('Finished training. Final batch loss %f' % final_loss)

# Model Testing