#### Install all dependencies
All other packages should already be preinstalled on Colab.

In [0]:
!pip install --upgrade tensorflow==1.13.1
# !pip install numpy==1.14.6
!pip install tifffile

In [0]:
import numpy as np
import csv
import six
import collections
from skimage.io import imread
import time
import os
from google.colab import drive
import random
from scipy.ndimage.interpolation import affine_transform
from scipy.ndimage.filters import gaussian_filter
import matplotlib.pyplot as plt
import tifffile

In [0]:
# Mount Google Drive at /content/gdrive; the image and label folders configured
# in the "Load data" cell are located underneath this mount point.
drive.mount('/content/gdrive')

#### Load data

In [0]:
# Input data are stacked TIFF volumes: one image stack and one label stack per ROI,
# stored under the same file name in the two folders below.

images_path = '/content/gdrive/My Drive/cropped-stacks/'
labels_path = '/content/gdrive/My Drive/cropped-labels-stacks/'

# ROIs listed here are excluded from the training set built further down.
# validation_rois feed get_random_validation_stack(); holdout_rois are only
# excluded from training in the cells shown here.
validation_rois = ('ROI_2052-5784-112', 'ROI_3588-3972-1')
holdout_rois = ('ROI_1656-6756-329', 'ROI_3624-2712-201', 'ROI_1716-7800-517')
# File extension appended to an ROI name to obtain its stack file name.
extension = ".tiff"

import os
from glob import glob
import zipfile

def get_file(file_pattern):
  """Return the first path matching the glob `file_pattern`, or "" when nothing matches."""
  matches = glob(file_pattern)
  return matches[0] if matches else ""

def unzip_zipfile(path):
  """Extract the first `*.zip` archive found at the `path` prefix, then delete the archive.

  `path` is used both as the string prefix of the glob pattern and as the
  extraction directory. Nothing happens when no archive matches.
  """
  archive = get_file(path + "*.zip")
  if not archive:
    return
  print("Unzipping: ", archive)
  with zipfile.ZipFile(archive, "r") as bundle:
    bundle.extractall(path)
  # The archive is removed once its contents have been unpacked.
  os.remove(archive)

# Unpack any uploaded archives before listing the data folders.
unzip_zipfile(images_path)
unzip_zipfile(labels_path)

# One ROI per label stack. Only entries carrying the expected extension are kept:
# os.listdir() also returns sub-directories and unrelated files (for example
# archive leftovers), and load_stack_pair() would fail on those later, in the
# middle of training. Sorted because os.listdir() order is arbitrary.
rois = sorted(os.path.splitext(filename)[0]
              for filename in os.listdir(labels_path)
              if filename.endswith(extension))

# Everything that is neither validation nor hold-out data is used for training.
training_rois = [roi for roi in rois
                 if roi not in validation_rois
                 and roi not in holdout_rois]


print("Training data:   ", training_rois)
print("Validation data: ", validation_rois)
print("Holdout data:    ", holdout_rois)


def load_stack_pair(roi):
  """Load the image stack and the binarised label stack belonging to `roi`.

  Returns a two-element list [image, label]; label is 1 where the stored label
  value is >= 0.5 and 0 everywhere else.
  """
  stack_file = roi + extension
  image = imread(images_path + stack_file)
  raw_label = imread(labels_path + stack_file)
  label = np.where(raw_label >= 0.5, 1, 0)  # threshold at 0.5 -> {0, 1}
  return [image, label]


def get_random_training_stack():
  """Pick one training ROI at random and return its [image, label] pair."""
  chosen_roi = random.choice(training_rois)
  return load_stack_pair(chosen_roi)


def get_random_validation_stack():
  """Pick one validation ROI at random and return its [image, label] pair."""
  chosen_roi = random.choice(validation_rois)
  return load_stack_pair(chosen_roi)


def check_data():
  """Print the image and label array shapes of every training and validation ROI."""
  sections = (("Training sets", training_rois),
              ("Validation sets", validation_rois))
  for heading, roi_names in sections:
    print(heading)
    for roi in roi_names:
      image, label = load_stack_pair(roi)
      print(roi, image.shape, label.shape)


**Data augmentation utils**

In [0]:
def transform_matrix_offset_center(matrix, x, y):
  """Conjugate a 3x3 homogeneous 2-D transform with a translation.

  The result applies `matrix` about the point (x / 2 + 0.5, y / 2 + 0.5)
  instead of about the origin.
  """
  centre_x = float(x) / 2 + 0.5
  centre_y = float(y) / 2 + 0.5
  to_centre = np.array([[1, 0, centre_x], [0, 1, centre_y], [0, 0, 1]])
  from_centre = np.array([[1, 0, -centre_x], [0, 1, -centre_y], [0, 0, 1]])
  return to_centre.dot(matrix).dot(from_centre)


def zoom_and_rotate_patch(patch, angle, zoom):
  """Apply one rotation + zoom transform to every 2-D slice along axis 0 of a 3-D patch.

  The patch shape is unpacked as (d, h, w). Each slice patch[i] is resampled
  independently by scipy's affine_transform with order=1 interpolation and
  'reflect' border handling, using a matrix built from `angle` (degrees) and
  `zoom` and centred via transform_matrix_offset_center(h, w).

  Returns the transformed slices stacked back along axis 0.
  """
  theta = np.deg2rad(angle)
  cos_t, sin_t = np.cos(theta), np.sin(theta)
  rotation = np.array([[cos_t, -sin_t, 0],
                       [sin_t, cos_t, 0],
                       [0, 0, 1]])
  scaling = np.array([[zoom, 0, 0],
                      [0, zoom, 0],
                      [0, 0, 1]])

  _, height, width = patch.shape
  centred = transform_matrix_offset_center(rotation.dot(scaling), height, width)
  # Split the homogeneous matrix into the 2x2 linear part and the translation.
  linear_part = centred[:2, :2]
  shift = centred[:2, 2]

  transformed_slices = [
    affine_transform(plane, linear_part, shift, order=1, mode='reflect')
    for plane in patch
  ]
  return np.stack(transformed_slices, axis=0)


def normalize_batch(batch):
  """Standardise every sample of a 4-D batch independently.

  Mean and standard deviation are computed over axes 1-3, i.e. separately for
  each entry along axis 0. 0.0001 is added to the standard deviation so that a
  constant sample does not divide by zero.
  """
  per_sample_axes = (1, 2, 3)
  centre = batch.mean(axis=per_sample_axes, keepdims=True)
  spread = batch.std(axis=per_sample_axes, keepdims=True)
  return (batch - centre) / (spread + 0.0001)


def sometimes(x):
  """Return True with probability `x` (e.g. 0.25 -> True about a quarter of the time)."""
  return np.random.random() < x


def _zoom_and_rotate_channels_last(patch, angle, zoom):
  """Run zoom_and_rotate_patch on a patch whose slice axis is last instead of first."""
  depth_first = np.moveaxis(patch, -1, 0)
  return np.moveaxis(zoom_and_rotate_patch(depth_first, angle, zoom), 0, -1)


def augment_and_normalize_batch(image_patches, label_patches):
  """Randomly augment a batch of patches, then normalise the images.

  Both arguments hold one patch per entry along axis 0, with the slice axis
  last, which is the layout get_train_generator produces with
  np.moveaxis(..., 1, 3).

  Per patch, each of the following is decided independently:
    * 50%: zoom / rotate image and label with the same random parameters
      (ZOOM_RANGE and ROTATION_RANGE are the standard deviations).
    * 25%: either Gaussian-blur the image (sigma magnitude drawn with standard
      deviation BLUR_RANGE) or add Gaussian noise (standard deviation NOISE).
      The label is left untouched.
    * 25%: either add a random brightness offset (BRIGHTNESS_RANGE) or multiply
      by a random contrast factor (CONTRAST_RANGE). The label is left untouched.

  Returns (image_patches, label_patches) as stacked arrays; only the images go
  through normalize_batch.
  """
  augmented_image_patches = []
  augmented_label_patches = []

  for i in range(len(image_patches)):

    img_patch = image_patches[i]
    lbl_patch = label_patches[i]

    if sometimes(0.5):  # 50% of the time, rotate or zoom the image
      zoom = np.random.normal(loc=1.0, scale=ZOOM_RANGE)
      angle = np.random.normal(loc=0.0, scale=ROTATION_RANGE)
      # zoom_and_rotate_patch unpacks its input as (d, h, w) and transforms the
      # slices along axis 0. The patches here have the slice axis last, so it is
      # moved to the front for the call and back afterwards. Passing the patch
      # as-is would transform the wrong planes.
      img_patch = _zoom_and_rotate_channels_last(img_patch, angle, zoom)
      lbl_patch = _zoom_and_rotate_channels_last(lbl_patch, angle, zoom)

    if sometimes(0.25):  # 25% of the time, blur OR noise up the image
      if sometimes(0.5):
        # A blur width cannot be negative; take the magnitude so that the half
        # of the draws falling below zero still blur instead of doing nothing.
        sigma = abs(np.random.normal(loc=0.0, scale=BLUR_RANGE))
        img_patch = gaussian_filter(img_patch, sigma=sigma)
        # don't blur the segmentation mask
      else:
        img_patch = img_patch + np.random.normal(loc=0.0, scale=NOISE, size=img_patch.shape)

    if sometimes(0.25):  # 25% of the time, change brightness or contrast
      if sometimes(0.5):
        img_patch = img_patch + np.random.normal(loc=0.0, scale=BRIGHTNESS_RANGE)
      else:
        img_patch = img_patch * np.random.normal(loc=1.0, scale=CONTRAST_RANGE)

    augmented_image_patches.append(img_patch)
    augmented_label_patches.append(lbl_patch)

  image_patches = np.stack(augmented_image_patches, axis=0)
  label_patches = np.stack(augmented_label_patches, axis=0)
  image_patches = normalize_batch(image_patches)

  return image_patches, label_patches

#### Data generator

In [0]:
def get_random_range(begin, end):
  """Draw a random integer from [begin, end); return `begin` itself when end <= begin."""
  if not end > begin:
    return begin
  return np.random.randint(begin, end)


def get_random_batch_corner_coordinates(batch_size, region):
  """Draw random patch corner coordinates inside `region`.

  @param batch_size: number of corner coordinates to generate
  @param region: (Z0, Z1, Y0, Y1, X0, X1) low/high bounds per axis; the upper
                 bound of each axis is reduced by the matching PATCH_SHAPE entry
                 before sampling, so a whole patch fits when the region is large
                 enough.

  @return: array of `batch_size` rows, each a (Z, Y, X) corner coordinate
  """
  def random_corner():
    # Axis order Z, Y, X: region holds (low, high) pairs, PATCH_SHAPE one size per axis.
    return [get_random_range(region[2 * axis], region[2 * axis + 1] - PATCH_SHAPE[axis])
            for axis in range(3)]

  return np.array([random_corner() for _ in range(batch_size)])


def get_image_patch(image, corner_coordinate):
  """Cut a PATCH_SHAPE-sized block out of `image`.

  :param image: array indexed along (at least) three axes
  :param corner_coordinate: start index for each of the first three axes
  :return: image[c0:c0 + P0, c1:c1 + P1, c2:c2 + P2] with P = PATCH_SHAPE
  """
  window = tuple(
    slice(corner_coordinate[axis], corner_coordinate[axis] + PATCH_SHAPE[axis])
    for axis in range(3)
  )
  return image[window]


def get_train_generator(batch_size):
  """Endlessly yield augmented (images, labels) training batches.

  Every batch comes from a single randomly chosen training stack, which is
  rotated by a random multiple of 90 degrees in the (1, 2) plane before
  `batch_size` random patches are cut from it.

  :param batch_size: number of patches per yielded batch
  :return: generator of (augmented_imgs, augmented_labels) tuples
  """
  while True:
    image, label = get_random_training_stack()

    # Same number of quarter turns for image and label so they stay aligned.
    quarter_turns = random.randint(0, 3)
    image = np.rot90(image, k=quarter_turns, axes=(1, 2))
    label = np.rot90(label, k=quarter_turns, axes=(1, 2))

    extent = image.shape
    region = (0, extent[0], 0, extent[1], 0, extent[2])
    corners = get_random_batch_corner_coordinates(batch_size, region)

    image_patches = np.array([get_image_patch(image, corner) for corner in corners])
    label_patches = np.array([get_image_patch(label, corner) for corner in corners])

    # Move axis 1 of the stacked patches to the end: tf --> (N, W, H, D)
    image_patches = np.moveaxis(image_patches, 1, 3)
    label_patches = np.moveaxis(label_patches, 1, 3)

    yield augment_and_normalize_batch(image_patches, label_patches)

def visualise_training_batch():
  """Plot slice 0 (last axis) of one augmented training patch with its label mask on top."""
  image_batch, label_batch = next(get_train_generator(1))
  image_slice = image_batch[0][:, :, 0]
  label_slice = label_batch[0][:, :, 0]
  # Mask out background (label == 0) so only labelled pixels are tinted.
  label_overlay = np.ma.masked_where(label_slice == 0, label_slice)
  plt.imshow(image_slice, cmap='gray')
  plt.imshow(label_overlay, vmin=0, vmax=1, cmap='cool', alpha=0.5)
  plt.show()

#visualise_training_batch()


In [0]:
def get_validation_data():
  """Build one validation set of 16 random patches from a randomly chosen validation stack.

  Corners are drawn in 4 rounds of 4. The patches are moved to slice-axis-last
  layout; the images are normalised with normalize_batch, the labels are not.

  Returns the tuple (x_patches, y_patches).
  """
  stack_x, stack_y = get_random_validation_stack()

  depth, height, width = stack_x.shape
  region = (0, depth, 0, height, 0, width)

  x_patches, y_patches = [], []
  for _ in range(4):
    for corner in get_random_batch_corner_coordinates(4, region):
      x_patches.append(get_image_patch(stack_x, corner))
      y_patches.append(get_image_patch(stack_y, corner))

  x_batch = normalize_batch(np.moveaxis(np.array(x_patches), 1, 3))
  y_batch = np.moveaxis(np.array(y_patches), 1, 3)

  return (x_batch, y_batch)

#### Model

**TPU is extremely fast, but unfortunately it's poorly integrated with Keras. Modification of the following callbacks is necessary to enable learning rate decay and training history logging. Maybe this will be fixed with TF 2.0, in which case, replace `ReduceLROnPlateauMODIFIED` with `ReduceLROnPlateau` and `CSVLoggerMODIFIED` with `CSVLogger` in get_callbacks().**

In [0]:
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, Conv2D, Concatenate, MaxPooling2D ,Conv2DTranspose, LeakyReLU
from tensorflow.keras.layers import UpSampling2D, Dropout, BatchNormalization
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras.callbacks import Callback, ModelCheckpoint
from tensorflow.keras.models import load_model
from tensorflow.python.util.tf_export import tf_export
from tensorflow.python.keras.callbacks import Callback

# Training schedule: fit_generator runs EPOCHS epochs of STEPS_PER_EPOCH batches,
# each batch holding BATCH_SIZE patches.
EPOCHS = 100
STEPS_PER_EPOCH = 100
BATCH_SIZE = 16
PATCH_SHAPE = (12, 256, 256)    # (Z, Y, X); note: X and Y appear to be used interchangeably
NUM_LAYERS = 4  # passed to UNet as `depth`
START_CH = 32  # passed to UNet as `start_ch`
# Overlap sizes: a quarter of the patch size. Not used in the cells shown here.
OVERLAP_X_Y = PATCH_SHAPE[1]//4
OVERLAP_Z = PATCH_SHAPE[0]//4
DROPOUT = 0.3  # passed to UNet as `dropout`
# Base path (relative to the working directory) for the checkpoint (.hdf5) and log (.log) files.
model_name = f"gdrive/My Drive/models/{NUM_LAYERS}L_{START_CH}ch_{PATCH_SHAPE}_{DROPOUT}DROPOUT"


# DATA AUGMENTATION CONFIGURATION
# Every value is the standard deviation of the normal distribution that
# augment_and_normalize_batch samples from.
ROTATION_RANGE = 0  # rotation angle in degrees, drawn around 0; 0 means the angle is always 0
ZOOM_RANGE = 0.05 # zoom factor, drawn around 1.0
CONTRAST_RANGE = 0.1 # multiplicative intensity factor, drawn around 1.0
BRIGHTNESS_RANGE = 0.1 # additive intensity offset, drawn around 0.0
BLUR_RANGE = 0.2 # Gaussian-blur sigma, drawn around 0.0
NOISE = 12  # additive per-element Gaussian noise, drawn around 0.0


# model construction


def conv_block(m, dim, bn, res, do=0):
  """Inception-style convolution block used at every U-Net level.

  Arguments:
      m: input tensor.
      dim: number of filters of every convolution in the block.
      bn: if truthy, apply BatchNormalization to the input first.
      res: if truthy, concatenate the block input `m` with the block output.
      do: dropout rate applied after the optional batch norm; 0 disables dropout.

  Returns:
      The concatenation of three towers (1x1 -> 3x3, 1x1 -> 5x5 and
      3x3 max-pool -> 1x1, each followed by LeakyReLU), optionally concatenated
      with `m`.
  """
  n = BatchNormalization()(m) if bn else m
  # Keep `n` when dropout is disabled. Falling back to `m` here would throw away
  # the batch-normalised tensor computed on the line above.
  n = Dropout(do)(n) if do else n
  # inception
  tower_1 = Conv2D(dim, 1, padding='same', activation='linear')(n)
  tower_1 = LeakyReLU(alpha=0.1)(tower_1)
  tower_1 = Conv2D(dim, (3, 3), padding='same', activation='linear')(tower_1)
  tower_1 = LeakyReLU(alpha=0.1)(tower_1)
  tower_2 = Conv2D(dim, (1, 1), padding='same', activation='linear')(n)
  tower_2 = LeakyReLU(alpha=0.1)(tower_2)
  tower_2 = Conv2D(dim, (5, 5), padding='same', activation='linear')(tower_2)
  tower_2 = LeakyReLU(alpha=0.1)(tower_2)
  tower_3 = MaxPooling2D((3, 3), strides=(1, 1), padding='same')(n)
  tower_3 = LeakyReLU(alpha=0.1)(tower_3)
  tower_3 = Conv2D(dim, (1, 1), padding='same', activation='linear')(tower_3)
  tower_3 = LeakyReLU(alpha=0.1)(tower_3)
  n = Concatenate()([tower_1, tower_2, tower_3])

  return Concatenate()([m, n]) if res else n


def level_block(m, dim, depth, inc, do, bn, mp, up, res):
  """Recursively build one U-Net level and everything below it.

  For depth > 0: conv_block -> downsample (max-pool if `mp`, else strided conv)
  -> recurse with int(inc * dim) filters -> upsample (UpSampling2D + conv if
  `up`, else transposed conv) -> LeakyReLU -> concatenate with the skip tensor
  -> conv_block. Otherwise: a single conv_block, the only one that receives the
  dropout rate `do`.
  """
  if not depth > 0:
    return conv_block(m, dim, bn, res, do)

  skip = conv_block(m, dim, bn, res)
  if mp:
    down = MaxPooling2D()(skip)
  else:
    down = Conv2D(dim, 3, strides=2, padding='same')(skip)

  inner = level_block(down, int(inc * dim), depth - 1, inc, do, bn, mp, up, res)

  if up:
    inner = UpSampling2D()(inner)
    inner = Conv2D(dim, 2, activation='linear', padding='same')(inner)
  else:
    inner = Conv2DTranspose(dim, 3, strides=2, activation='linear', padding='same')(inner)
  inner = LeakyReLU(alpha=0.1)(inner)

  merged = Concatenate()([skip, inner])
  return conv_block(merged, dim, bn, res)


def UNet(img_shape, out_ch=1, start_ch=64, depth=4, inc_rate=2.,
         dropout=0.5, batchnorm=True, maxpool=True, upconv=True, residual=True):
  """Assemble the U-Net Keras model.

  The input of shape `img_shape` is passed through level_block (configured by
  start_ch, depth, inc_rate, dropout, batchnorm, maxpool, upconv and residual)
  and then through a 1x1 sigmoid convolution with `out_ch` output channels.
  """
  inputs = Input(shape=img_shape)
  features = level_block(inputs, start_ch, depth, inc_rate, dropout,
                         batchnorm, maxpool, upconv, residual)
  outputs = Conv2D(out_ch, 1, activation='sigmoid')(features)
  return Model(inputs=inputs, outputs=outputs)


# loss functions


def dice_coef(y_true, y_pred, smooth=1):
    """Soft Dice coefficient computed over all elements of both tensors.

    Returns (2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth).
    """
    truth = K.flatten(y_true)
    prediction = K.flatten(y_pred)

    overlap = K.sum(truth * prediction)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth) + K.sum(prediction) + smooth
    return numerator / denominator

def dice_coef_loss(y_true, y_pred):
    """Dice loss: one minus the soft Dice coefficient."""
    score = dice_coef(y_true, y_pred)
    return 1 - score

def bce_dice_loss(y_true, y_pred):
    """Sum of binary cross-entropy and the Dice loss."""
    cross_entropy = binary_crossentropy(y_true, y_pred)
    dice_term = dice_coef_loss(y_true, y_pred)
    return cross_entropy + dice_term


# callbacks


@tf_export('keras.callbacks.ReduceLROnPlateau')
class ReduceLROnPlateauMODIFIED(Callback):
  """Reduce learning rate when a metric has stopped improving.

  Models often benefit from reducing the learning rate by a factor
  of 2-10 once learning stagnates. This callback monitors a
  quantity and if no improvement is seen for a 'patience' number
  of epochs, the learning rate is reduced.

  This variant reads and writes the learning rate through
  `self.model.optimizer.optimizer._opt._lr` and stores the current value in
  `logs['lr']` at the end of every epoch, so that callbacks running after it
  can log it.

  NOTE(review): `_lr` is a private attribute. Confirm that reassigning it after
  the training graph has been built really changes the rate that is applied.

  Example:

  ```python
  reduce_lr = ReduceLROnPlateauMODIFIED(monitor='val_loss', factor=0.2,
                                        patience=5, min_lr=0.001)
  model.fit(X_train, Y_train, callbacks=[reduce_lr])
  ```

  Arguments:
      monitor: quantity to be monitored.
      factor: factor by which the learning rate will
          be reduced. new_lr = lr * factor
      patience: number of epochs with no improvement
          after which learning rate will be reduced.
      verbose: int. 0: quiet, 1: update messages.
      mode: one of {auto, min, max}. In `min` mode,
          lr will be reduced when the quantity
          monitored has stopped decreasing; in `max`
          mode it will be reduced when the quantity
          monitored has stopped increasing; in `auto`
          mode, the direction is automatically inferred
          from the name of the monitored quantity.
      min_delta: threshold for measuring the new optimum,
          to only focus on significant changes.
      cooldown: number of epochs to wait before resuming
          normal operation after lr has been reduced.
      min_lr: lower bound on the learning rate.

  Raises:
      ValueError: if `factor` is >= 1.0.
  """

  def __init__(self,
               monitor='val_loss',
               factor=0.1,
               patience=10,
               verbose=0,
               mode='auto',
               min_delta=1e-4,
               cooldown=0,
               min_lr=0,
               **kwargs):
    super().__init__()

    self.monitor = monitor
    if factor >= 1.0:
      raise ValueError('ReduceLROnPlateau ' 'does not support a factor >= 1.0.')
    if 'epsilon' in kwargs:
      min_delta = kwargs.pop('epsilon')
      # tf.logging is used because no `logging` module is imported in this
      # notebook; `tf` is.
      tf.logging.warning('`epsilon` argument is deprecated and '
                         'will be removed, use `min_delta` instead.')
    self.factor = factor
    self.min_lr = min_lr
    self.min_delta = min_delta
    self.patience = patience
    self.verbose = verbose
    self.cooldown = cooldown
    self.cooldown_counter = 0  # Cooldown counter.
    self.wait = 0
    self.best = 0
    self.mode = mode
    self.monitor_op = None
    self._reset()

  def _reset(self):
    """Resets wait counter and cooldown counter.
    """
    if self.mode not in ['auto', 'min', 'max']:
      tf.logging.warning('Learning Rate Plateau Reducing mode %s is unknown, '
                         'fallback to auto mode.', self.mode)
      self.mode = 'auto'
    if (self.mode == 'min' or
        (self.mode == 'auto' and 'acc' not in self.monitor)):
      # Lower is better: improvement means dropping below best - min_delta.
      self.monitor_op = lambda a, b: np.less(a, b - self.min_delta)
      self.best = np.inf
    else:
      # Higher is better: improvement means rising above best + min_delta.
      self.monitor_op = lambda a, b: np.greater(a, b + self.min_delta)
      self.best = -np.inf
    self.cooldown_counter = 0
    self.wait = 0

  def _inner_optimizer(self):
    """Return the wrapped optimizer object that carries the `_lr` attribute."""
    return self.model.optimizer.optimizer._opt

  def on_train_begin(self, logs=None):
    self._reset()

  def on_epoch_end(self, epoch, logs=None):
    """Publish the current learning rate and reduce it after `patience` stale epochs."""
    logs = logs or {}
    # Expose the rate to callbacks that run after this one.
    logs['lr'] = self._inner_optimizer()._lr
    current = logs.get(self.monitor)
    if current is None:
      tf.logging.warning('Reduce LR on plateau conditioned on metric `%s` '
                         'which is not available. Available metrics are: %s',
                         self.monitor, ','.join(list(logs.keys())))

    else:
      if self.in_cooldown():
        self.cooldown_counter -= 1
        self.wait = 0

      if self.monitor_op(current, self.best):
        self.best = current
        self.wait = 0
      elif not self.in_cooldown():
        self.wait += 1
        if self.wait >= self.patience:
          old_lr = float(self._inner_optimizer()._lr)
          if old_lr > self.min_lr:
            new_lr = old_lr * self.factor
            new_lr = max(new_lr, self.min_lr)
            self._inner_optimizer()._lr = new_lr
            if self.verbose > 0:
              print('\nEpoch %05d: ReduceLROnPlateau reducing learning '
                    'rate to %s.' % (epoch + 1, new_lr))
            self.cooldown_counter = self.cooldown
            self.wait = 0

  def in_cooldown(self):
    """Return True while the post-reduction cooldown is still running."""
    return self.cooldown_counter > 0


@tf_export('keras.callbacks.CSVLogger')
class CSVLoggerMODIFIED(Callback):
  """Callback that streams epoch results to a csv file.

  After every epoch one row `val_loss, loss, lr` is appended to the file. No
  header row is written. Every field is quoted with `'`. A value missing from
  `logs` is written as `NA`. `lr` is only present when a callback running
  earlier in the list has put it into `logs` (ReduceLROnPlateauMODIFIED does).

  Example:

  ```python
  csv_logger = CSVLoggerMODIFIED('training.log')
  model.fit(X_train, Y_train, callbacks=[csv_logger])
  ```

  Arguments:
      filename: filename of the csv file, e.g. 'run/log.csv'.
      separator: string used to separate elements in the csv file.
      append: True: append if file exists (useful for continuing
          training). False: overwrite existing file,
  """

  # Keys written per epoch, in column order.
  _COLUMNS = ('val_loss', 'loss', 'lr')

  def __init__(self, filename, separator=',', append=False):
    self.sep = separator
    self.filename = filename
    self.append = append
    self.writer = None
    self.keys = None
    self.append_header = True
    if six.PY2:
      self.file_flags = 'b'
      self._open_args = {}
    else:
      self.file_flags = ''
      self._open_args = {'newline': '\n'}
    super(CSVLoggerMODIFIED, self).__init__()

  def on_train_begin(self, logs=None):
    """Record the file mode and, when not appending, start from an empty file."""
    if self.append:
      self.mode = 'a'
    else:
      self.mode = 'w'
      # Rows are appended once per epoch, so honouring append=False means
      # truncating the file once here.
      with open(self.filename, 'w'):
        pass

  def on_epoch_end(self, epoch, logs=None):
    """Append one row with the values of `_COLUMNS` for this epoch."""
    logs = logs or {}
    # A missing metric must not abort training; write a placeholder instead.
    row = [logs.get(key, 'NA') for key in self._COLUMNS]

    with open(self.filename, 'a') as f:
      writer = csv.writer(f, delimiter=self.sep, quotechar="'", quoting=csv.QUOTE_ALL)
      writer.writerow(row)

  def on_train_end(self, logs=None):
    self.writer = None
    

def get_callbacks(model_name):
  """Return the training callbacks: checkpointing, learning-rate decay and CSV logging.

  `model_name` is the base path; '.hdf5' and '.log' are appended for the
  checkpoint and the log file. The learning-rate callback is placed before the
  logger because the logger writes the `lr` entry that callback adds to `logs`.
  """
  checkpoint = ModelCheckpoint(
    filepath=model_name + '.hdf5',
    verbose=1,
    monitor='loss',
    save_best_only=True,
    period=1,
    mode='min'
  )

  lr_decay = ReduceLROnPlateauMODIFIED(
    monitor='val_loss',
    factor=0.75,   # lr = lr*factor
    patience=4,  # how many epochs no change
    verbose=1
  )

  history_log = CSVLoggerMODIFIED(model_name + '.log', append=True)

  return [checkpoint, lr_decay, history_log]


#### Model Training

In [0]:
# Build the network. The input is one patch in (Y, X, Z) order, and the model
# emits PATCH_SHAPE[0] output channels, i.e. one per Z slice of the patch.
model = UNet(
  (PATCH_SHAPE[1], PATCH_SHAPE[2], PATCH_SHAPE[0]),  # bc tf expects (W, H, D)
  start_ch=START_CH,
  depth=NUM_LAYERS,
  out_ch=PATCH_SHAPE[0],
  dropout=DROPOUT,
  residual=True,
  upconv=False
)

# Optional: Load model from last saved params
# If a checkpoint with the same name exists, it replaces the freshly built model.
model_filename = model_name + ".hdf5"
if os.path.exists(model_filename):
  print("loading model: " + model_filename)
  model = load_model(model_filename, custom_objects={'dice_coef_loss': dice_coef_loss})

# (Re)compile with a TF-native Adam optimizer and the Dice loss. This also runs
# after a checkpoint has been loaded.
model.compile(
# optimizer=tf.keras.optimizers.Adam(lr=0.001),
  optimizer=tf.train.AdamOptimizer(learning_rate=0.0005),
  loss=dice_coef_loss
)


# COLAB_TPU_ADDR must be set in the environment; otherwise this raises KeyError.
TPU_WORKER = 'grpc://' + os.environ['COLAB_TPU_ADDR']  # get TPU address
tf.logging.set_verbosity(tf.logging.INFO)

strategy = tf.contrib.tpu.TPUDistributionStrategy(
  tf.contrib.cluster_resolver.TPUClusterResolver(TPU_WORKER)
)

# Convert the compiled Keras model into a TPU model.
tpu_model = tf.contrib.tpu.keras_to_tpu_model(
  model,
  strategy=strategy, 
)

#print(model.summary())

# get_validation_data() is evaluated once, here, so every epoch is validated on
# the same set of patches.
tpu_model.fit_generator(
  generator=get_train_generator(BATCH_SIZE),
  steps_per_epoch=STEPS_PER_EPOCH,
  epochs=EPOCHS,
  validation_data=get_validation_data(),
  callbacks=get_callbacks(model_name),
)
