In [1]:
import time
import os

from medpy.io import load, save
import numpy as np

import tflearn
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d, upsample_2d

import tensorflow as tf

In [2]:
# --- Paths and run options ---------------------------------------------------
dataPath = './data/'    # root directory walked for the image / mask volumes
modelPath = './model/'  # directory the training checkpoint is written to
fine_tune = True        # NOTE(review): not read by any of the visible cells
loss_method = 'weighted_cross_entropy'  # 'cross_entropy' | 'weighted_cross_entropy' | 'dice'

# --- Training parameters -----------------------------------------------------
batch_size = 16    # also baked into the shape of the `weights` placeholder below
display_step = 20  # training steps between progress prints / checkpoint saves

# --- Network parameters ------------------------------------------------------
width = 256
height = 256
n_channels = 1
n_classes = 2 # total classes (brain, non-brain)

# Number of 2D slices preallocated for the in-memory image / label arrays.
NumberOfSamples = 1259

# Start from an empty default graph so that re-running this cell does not
# accumulate duplicate ops. A single reset is sufficient.
tf.reset_default_graph()

# --- Graph inputs ------------------------------------------------------------
x = tf.placeholder(tf.float32, [None, width, height, n_channels])  # input slices
y = tf.placeholder(tf.float32, [None, width, height, n_classes])   # one-hot labels
lr = tf.placeholder(tf.float32)                                    # learning rate, fed per step
# One weight per pixel of a full batch, flattened; the fixed length means every
# fed batch must contain exactly `batch_size` slices.
weights = tf.placeholder(tf.float32, [batch_size*width*height])


In [3]:
def generate_batch():
    """Yield ``(image_batch, label_batch)`` tuples for training.

    Each batch is selected from the module-level ``images`` / ``labels``
    arrays using the index array produced by ``generate_samples()``, and every
    sample in it is passed through ``augment_sample`` before being yielded.
    """
    for batch_ids in generate_samples():
        # Selecting with an index array yields new arrays, so the per-sample
        # writes below are made on the batch and not on `images` / `labels`.
        batch_images, batch_labels = images[batch_ids], labels[batch_ids]
        for row, (img, lab) in enumerate(zip(batch_images, batch_labels)):
            batch_images[row], batch_labels[row] = augment_sample(img, lab)
        yield batch_images, batch_labels

def generate_samples(n_samples=None, n_epochs=1000, samples_per_batch=None):
    """Yield index arrays that select one mini-batch of samples at a time.

    For each epoch a fresh random permutation of ``range(n_samples)`` is drawn
    and cut into consecutive chunks of ``samples_per_batch`` indices. Only full
    chunks are yielded; the ``n_samples % samples_per_batch`` indices left at
    the end of each permutation are dropped for that epoch.

    Args:
        n_samples: Number of samples to draw indices from. Defaults to the
            module-level ``NumberOfSamples``.
        n_epochs: Number of passes (permutations) to generate.
        samples_per_batch: Number of indices per yielded array. Defaults to
            the module-level ``batch_size``.

    Yields:
        1-D integer numpy arrays of length ``samples_per_batch``.
    """
    # Fall back to the notebook-level configuration when called without
    # arguments, which is how the existing caller uses it.
    if n_samples is None:
        n_samples = NumberOfSamples
    if samples_per_batch is None:
        samples_per_batch = batch_size

    # Floor division gives the number of complete batches directly.
    n_batches = n_samples // samples_per_batch
    for _ in range(n_epochs):
        sample_ids = np.random.permutation(n_samples)
        for i in range(n_batches):
            yield sample_ids[i * samples_per_batch:(i + 1) * samples_per_batch]

def augment_sample(image, label):
    """Augmentation hook applied to a single (image, label) pair.

    Currently a placeholder: no transformation is performed and both
    arguments are handed back unchanged as a 2-tuple.
    """
    return image, label

In [4]:
# Encoder/decoder segmentation network built on the `x` placeholder.
# The order in which layers are created is kept fixed on purpose: tflearn
# derives default layer names from creation order, and checkpoints refer to
# variables by those names.

# ---- Contracting path: two 3x3 convolutions, then max-pooling with kernel 2 ----
conv1 = conv_2d(x, 32, 3, activation='relu', padding='same', regularizer="L2")
conv1 = conv_2d(conv1, 32, 3, activation='relu', padding='same', regularizer="L2")
pool1 = max_pool_2d(conv1, 2)

conv2 = conv_2d(pool1, 64, 3, activation='relu', padding='same', regularizer="L2")
conv2 = conv_2d(conv2, 64, 3, activation='relu', padding='same', regularizer="L2")
pool2 = max_pool_2d(conv2, 2)

conv3 = conv_2d(pool2, 128, 3, activation='relu', padding='same', regularizer="L2")
conv3 = conv_2d(conv3, 128, 3, activation='relu', padding='same', regularizer="L2")
pool3 = max_pool_2d(conv3, 2)

conv4 = conv_2d(pool3, 256, 3, activation='relu', padding='same', regularizer="L2")
conv4 = conv_2d(conv4, 256, 3, activation='relu', padding='same', regularizer="L2")
pool4 = max_pool_2d(conv4, 2)

# ---- Bottleneck ----
conv5 = conv_2d(pool4, 512, 3, activation='relu', padding='same', regularizer="L2")
conv5 = conv_2d(conv5, 512, 3, activation='relu', padding='same', regularizer="L2")

# ---- Expanding path: upsample, concatenate the matching encoder features
# along the channel axis (axis=3), then two 3x3 convolutions ----
up6 = upsample_2d(conv5,2)
up6 = tflearn.layers.merge_ops.merge([up6, conv4], 'concat',axis=3)
conv6 = conv_2d(up6, 256, 3, activation='relu', padding='same', regularizer="L2")
conv6 = conv_2d(conv6, 256, 3, activation='relu', padding='same', regularizer="L2")

up7 = upsample_2d(conv6,2)
up7 = tflearn.layers.merge_ops.merge([up7, conv3],'concat', axis=3)
conv7 = conv_2d(up7, 128, 3, activation='relu', padding='same', regularizer="L2")
conv7 = conv_2d(conv7, 128, 3, activation='relu', padding='same', regularizer="L2")

up8 = upsample_2d(conv7,2)
up8 = tflearn.layers.merge_ops.merge([up8, conv2],'concat', axis=3)
conv8 = conv_2d(up8, 64, 3, activation='relu', padding='same', regularizer="L2")
conv8 = conv_2d(conv8, 64, 3, activation='relu', padding='same', regularizer="L2")

up9 = upsample_2d(conv8,2)
up9 = tflearn.layers.merge_ops.merge([up9, conv1],'concat', axis=3)
conv9 = conv_2d(up9, 32, 3, activation='relu', padding='same', regularizer="L2")
conv9 = conv_2d(conv9, 32, 3, activation='relu', padding='same', regularizer="L2")

# ---- Output: 1x1 convolution producing one linear score (logit) per class.
# The filter count follows `n_classes` so it cannot drift from the label
# placeholder and the reshape below.
pred = conv_2d(conv9, n_classes, 1,  activation='linear', padding='valid')

# Per-pixel logits flattened to (num_pixels, n_classes). The loss cell
# re-creates `pred_reshape` with an explicit batch dimension.
pred_reshape = tf.reshape(pred, [-1, n_classes])

Instructions for updating:
Use tf.initializers.variance_scaling instead with distribution=uniform to get equivalent behavior.


In [5]:
# Preallocate the whole dataset in memory: one row per 2D slice.
images = np.zeros((NumberOfSamples, width, height, n_channels))
labels = np.zeros((NumberOfSamples, width, height, n_classes))

# Index of the next free row in `images` / `labels`.
slice_counter = 0
for path, subdirs, files in os.walk(dataPath):
    for name in files:
        # Image volumes are the files with "nii" in the name that are not
        # masks; the matching mask is the same file name prefixed with 'mask'.
        if "mask" not in name and "nii" in name:
            image_data, image_header = load(os.path.join(path, name)) # Load data
            mask_data, mask_header = load(os.path.join(path, 'mask' + name)) # Load mask
            image_data = np.moveaxis(image_data, -1, 0) # Bring the last dim to the first
            mask_data = np.moveaxis(mask_data, -1, 0) # Bring the last dim to the first

            n_slices = image_data.shape[0]
            if slice_counter + n_slices > NumberOfSamples:
                raise ValueError(
                    "Dataset has more than NumberOfSamples=%d slices (overflow while loading %s)"
                    % (NumberOfSamples, os.path.join(path, name)))
            rows = slice(slice_counter, slice_counter + n_slices)

            # Scale each volume by its own maximum; skip the division for a
            # volume whose maximum is 0 so no NaNs are written.
            peak = np.max(image_data)
            images[rows, :, :, 0] = image_data / peak if peak != 0 else image_data

            # Two-channel one-hot labels: channel 0 = mask, channel 1 = complement.
            labels[rows, :, :, 0] = mask_data
            labels[rows, :, :, 1] = 1 - mask_data

            # Advance past the rows just written so the next volume does not
            # overwrite them.
            slice_counter += n_slices

In [6]:
######################################
# Define loss and optimizer
# Flatten logits and one-hot labels to one row per pixel. The batch dimension
# is fixed to `batch_size`, matching the `weights` placeholder.
pred_reshape = tf.reshape(pred, [batch_size * width * height, n_classes])
y_reshape = tf.reshape(y, [batch_size * width * height, n_classes])

if loss_method == 'cross_entropy':
    # Plain per-pixel softmax cross-entropy, averaged over all pixels.
    error = tf.nn.softmax_cross_entropy_with_logits(labels = y_reshape , logits = pred_reshape)
    cost = tf.reduce_mean(error)

elif loss_method == 'weighted_cross_entropy':
    # Per-pixel cross-entropy scaled by the fed per-pixel `weights`, then
    # normalised by the total weight (a weighted mean), so the class weights
    # actually influence the loss.
    error = tf.nn.softmax_cross_entropy_with_logits(labels = y_reshape , logits = pred_reshape)
    cost = tf.reduce_sum(weights * error) / tf.reduce_sum(weights)

elif loss_method == 'dice':
    # Soft Dice on class probabilities: `pred` is a linear (logit) output, so
    # it is passed through softmax first. The +1 terms smooth the ratio.
    probs = tf.nn.softmax(pred_reshape)
    intersection = tf.reduce_sum(probs * y_reshape)
    cost = -(2 * intersection + 1)/(tf.reduce_sum(probs) + tf.reduce_sum(y_reshape) + 1)
    
else:
    raise NotImplementedError

# The learning rate is a placeholder so it can be changed between steps.
optimizer = tf.train.AdamOptimizer(learning_rate=lr).minimize(cost)

# Evaluate model
# Pixel accuracy: fraction of pixels whose arg-max class matches the label.
correct_pred = tf.equal(tf.argmax(pred, -1), tf.argmax(y, -1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Initializing the variables
init = tf.global_variables_initializer()

# Class weights: inverse class frequency raised to the power 0.3, normalised
# to sum to 1. Frequencies are measured over the whole `labels` array.
arg_labels = np.argmax(labels, axis = -1)
class_weights = np.zeros(n_classes)
for i in range(n_classes):
    class_weights[i] = 1 / np.mean(arg_labels == i) ** 0.3
class_weights /= np.sum(class_weights)

sess = tf.Session()
sess.run(init)

saver = tf.train.Saver()
model_path = os.path.join(modelPath, 'Unet.ckpt')

# Initial learning rate fed to `lr`; the training loop decays it.
learning_rate = 0.00001

Instructions for updating:

Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.

See @{tf.nn.softmax_cross_entropy_with_logits_v2}.

INFO:tensorflow:Restoring parameters from ./model/Fetal_2D_norm.ckpt


DataLossError: Unable to open table file ./model/Fetal_2D_norm.ckpt: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator?
	 [[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

Caused by op u'save/RestoreV2', defined at:
  File "/usr/lib/python2.7/runpy.py", line 174, in _run_module_as_main
    "__main__", fname, loader, pkg_name)
  File "/usr/lib/python2.7/runpy.py", line 72, in _run_code
    exec code in run_globals
  File "/usr/local/lib/python2.7/dist-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/usr/local/lib/python2.7/dist-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/usr/local/lib/python2.7/dist-packages/ipykernel/kernelapp.py", line 486, in start
    self.io_loop.start()
  File "/usr/local/lib/python2.7/dist-packages/tornado/ioloop.py", line 1073, in start
    handler_func(fd_obj, events)
  File "/usr/local/lib/python2.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/zmq/eventloop/zmqstream.py", line 450, in _handle_events
    self._handle_recv()
  File "/usr/local/lib/python2.7/dist-packages/zmq/eventloop/zmqstream.py", line 480, in _handle_recv
    self._run_callback(callback, msg)
  File "/usr/local/lib/python2.7/dist-packages/zmq/eventloop/zmqstream.py", line 432, in _run_callback
    callback(*args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/ipykernel/kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/usr/local/lib/python2.7/dist-packages/ipykernel/kernelbase.py", line 233, in dispatch_shell
    handler(stream, idents, msg)
  File "/usr/local/lib/python2.7/dist-packages/ipykernel/kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "/usr/local/lib/python2.7/dist-packages/ipykernel/ipkernel.py", line 208, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/usr/local/lib/python2.7/dist-packages/ipykernel/zmqshell.py", line 537, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/IPython/core/interactiveshell.py", line 2714, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/usr/local/lib/python2.7/dist-packages/IPython/core/interactiveshell.py", line 2818, in run_ast_nodes
    if self.run_code(code, result):
  File "/usr/local/lib/python2.7/dist-packages/IPython/core/interactiveshell.py", line 2878, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-6-48ec0c30546c>", line 38, in <module>
    saver = tf.train.Saver()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 1281, in __init__
    self.build()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 1293, in build
    self._build(self._filename, build_save=True, build_restore=True)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 1330, in _build
    build_save=build_save, build_restore=build_restore)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 778, in _build_internal
    restore_sequentially, reshape)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 397, in _AddRestoreOps
    restore_sequentially)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 829, in bulk_restore
    return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_io_ops.py", line 1463, in restore_v2
    shape_and_slices=shape_and_slices, dtypes=dtypes, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/util/deprecation.py", line 454, in new_func
    return func(*args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 3155, in create_op
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1717, in __init__
    self._traceback = tf_stack.extract_stack()

DataLossError (see above for traceback): Unable to open table file ./model/Fetal_2D_norm.ckpt: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator?
	 [[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]


In [None]:
#######################Train###################################
lr_decay_every = 2000   # steps between learning-rate decays
lr_decay_factor = 0.9   # multiplicative decay applied each time

for step2, (image_batch, label_batch) in enumerate(generate_batch()):            
    # Per-pixel class index for the batch, flattened to match `weights`,
    # then mapped to the corresponding class weight.
    label_vect = np.reshape(np.argmax(label_batch, axis=-1), [batch_size * width * height])
    weight_vect = class_weights[label_vect]
    # Fit training using batch data
    feed_dict = {x: image_batch, y: label_batch, weights: weight_vect, lr:learning_rate}
    loss, acc, _ = sess.run([cost, accuracy, optimizer], feed_dict=feed_dict)
    if step2 % display_step == 0:
        print("Step %d, Minibatch Loss=%0.6f , Training Accuracy=%0.5f " 
              % (step2, loss, acc))

        # Save the variables to disk.
        saver.save(sess, model_path)
    # Decay only after a full interval has elapsed; step 0 is skipped so the
    # configured initial learning rate is used for the first interval.
    if step2 > 0 and step2 % lr_decay_every == 0:
        learning_rate *= lr_decay_factor