#### Install dependencies - everything else should be installed automatically by colab

In [0]:
!pip install tifffile
!pip install --upgrade tensorflow==1.13.1
# !pip install numpy==1.14.6

Collecting tifffile
[?25l  Downloading https://files.pythonhosted.org/packages/ca/96/2fcac22c806145b34e682e03874b490ae09bc3e48013a0c77e590cd6be29/tifffile-2019.7.26-py2.py3-none-any.whl (131kB)
[K     |██▌                             | 10kB 14.2MB/s eta 0:00:01[K     |█████                           | 20kB 2.3MB/s eta 0:00:01[K     |███████▌                        | 30kB 3.2MB/s eta 0:00:01[K     |██████████                      | 40kB 2.1MB/s eta 0:00:01[K     |████████████▌                   | 51kB 2.6MB/s eta 0:00:01[K     |███████████████                 | 61kB 3.1MB/s eta 0:00:01[K     |█████████████████▌              | 71kB 3.6MB/s eta 0:00:01[K     |████████████████████            | 81kB 4.1MB/s eta 0:00:01[K     |██████████████████████▌         | 92kB 4.5MB/s eta 0:00:01[K     |█████████████████████████       | 102kB 3.5MB/s eta 0:00:01[K     |███████████████████████████▌    | 112kB 3.5MB/s eta 0:00:01[K     |██████████████████████████████  | 122kB 3.5M

In [0]:
from tensorflow.keras.models import load_model
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras.models import load_model
import numpy as np
import time
import os
import tensorflow as tf
import tensorflow.keras.backend as K
from google.colab import drive
import matplotlib.pyplot as plt
import tifffile

print('TENSORFLOW VERSION ', tf.__version__)

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


TENSORFLOW VERSION  1.13.1


In [0]:
drive.mount('/content/gdrive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/gdrive


#### Config

# **IMPORTANT**
You must modify the following `training_paths` with tuples of form `[(PATH_TO_TRAINING_X, PATH_TO_TRAINING_Y), ...]`.

You must also modify `validation_data` with your own validation data paths, in the same tuple form as above.

The config parameters must be the same as in the train notebook.

In [0]:
BATCH_SIZE = 16
PATCH_SHAPE = (12, 256, 256)
NUM_LAYERS = 4
START_CH = 32
OVERLAP_X_Y = PATCH_SHAPE[1]//4
OVERLAP_Z = PATCH_SHAPE[0]//4
DROPOUT = 0.3
model_fn = f"gdrive/My Drive/models/{NUM_LAYERS}L_{START_CH}ch_{PATCH_SHAPE}_{DROPOUT}DROPOUT iter25"


# DATA AUGMENTATION CONFIGURATION
ROTATION_RANGE = 0  # +/- ~90˚
ZOOM_RANGE = 0.05 # +/- ~10% zoom
CONTRAST_RANGE = 0.1 # +/- ~10% constrast
BRIGHTNESS_RANGE = 0.1 # +/- ~10% brightness
BLUR_RANGE = 0.2 # blur +/- 2 sigma
NOISE = 12  # range of noise +/- 12 brightness

#### Loss

In [0]:
def dice_coef(y_true, y_pred, smooth=1):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_coef_loss(y_true, y_pred):
    return 1-dice_coef(y_true, y_pred)

def bce_dice_loss(y_true, y_pred):
    return binary_crossentropy(y_true, y_pred) + dice_coef_loss(y_true, y_pred)

**Data utils**

In [0]:
def get_random_batch_corner_coordinates(batch_size, region):
  """ @param region: (Z0, Z1, Y0, Y1, X0, X1) return coordniates in high/low range given
      @param batch_size: how many random coordinates to generate?

      @return: (batch_size, 3) stacks of random (Z, Y, X) coordinates
  """
  r = np.array([[np.random.randint(region[0], region[1] - PATCH_SHAPE[0]),
                 np.random.randint(region[2], region[3] - PATCH_SHAPE[1]),
                 np.random.randint(region[4], region[5] - PATCH_SHAPE[2])
                 ] for _ in range(batch_size)])
  return r


def get_image_patch(image, corner_coordinate):
  """

  :param image:
  :param corner_coordinate:
  :param patch_shape:
  :return:
  """
  patch = image[corner_coordinate[0]:corner_coordinate[0] + PATCH_SHAPE[0],
                corner_coordinate[1]:corner_coordinate[1] + PATCH_SHAPE[1],
                corner_coordinate[2]:corner_coordinate[2] + PATCH_SHAPE[2]]
  return patch

def normalize_batch(batch):
  """ normalize X batch by subbing mean then dividing by std
      normalized over the 0 axis (patch-wise mean and std)
  """
  mean = batch.mean(axis=(1, 2, 3), keepdims=True)
  std = batch.std(axis=(1, 2, 3), keepdims=True)
  batch = (batch - mean) / (std + 0.0001)
  return batch

#### Make patch-wise predictions

In [0]:
# you may want to load a different model - make sure that patch size is correct
model_filename = model_fn + '.hdf5'
print("Loading model: " + model_filename)
model = load_model(model_filename, custom_objects={'dice_coef_loss': dice_coef_loss})

model.compile(
# optimizer=tf.keras.optimizers.Adam(lr=0.001),
  optimizer=tf.train.AdamOptimizer(learning_rate=0.001),
  loss=dice_coef_loss
)


TPU_WORKER = 'grpc://' + os.environ['COLAB_TPU_ADDR']  # get TPU address
tf.logging.set_verbosity(tf.logging.INFO)

strategy = tf.contrib.tpu.TPUDistributionStrategy(
  tf.contrib.cluster_resolver.TPUClusterResolver(TPU_WORKER)
)

tpu_model = tf.contrib.tpu.keras_to_tpu_model(
  model,
  strategy=strategy, 
)

Loading model: gdrive/My Drive/models/4L_32ch_(12, 256, 256)_0.3DROPOUT iter25.hdf5
INFO:tensorflow:Querying Tensorflow master (grpc://10.127.69.194:8470) for TPU system metadata.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 7174341447447038467)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 15389719267837346949)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 10087768657594870198)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 7588794946352429535)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179

In [0]:
def split(a, n):
  """ split a list into n chunks of approx equal length """
  k, m = divmod(len(a), n)
  return (a[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n))


def run_prediction(model, volume):
  """
  :model: a keras model object
  :volume: a 3D image volume (ndarray) to obtain the 3d segmentation map for
  """

  # padding guarantees that the whole volume is convolved over.
  padded_volume = np.pad(
      volume, 
      (
       (OVERLAP_Z + PATCH_SHAPE[0], OVERLAP_Z + PATCH_SHAPE[0]),
       (OVERLAP_X_Y + PATCH_SHAPE[1], OVERLAP_X_Y + PATCH_SHAPE[1]),
       (OVERLAP_X_Y + PATCH_SHAPE[2], OVERLAP_X_Y + PATCH_SHAPE[2])
      ),
      'symmetric'
  )
  
  D, H, W = padded_volume.shape
  print(padded_volume.shape)
  
  grid_coordinates = [
      (z, y, x) 
      for z in range(0, D-PATCH_SHAPE[0], PATCH_SHAPE[0]-(OVERLAP_Z*2))
      for y in range(0, H-PATCH_SHAPE[1], PATCH_SHAPE[1]-(OVERLAP_X_Y*2))
      for x in range(0, W-PATCH_SHAPE[2], PATCH_SHAPE[2]-(OVERLAP_X_Y*2)) 
  ]
  
  result_volume = np.zeros_like(padded_volume, dtype=np.float32)
  # split up computation into batches because thats a lot of patches
  print(f'Running prediction over {len(grid_coordinates)} patches')
  
  # make divisible by 8 for tpu
  tpu_coords = [grid_coordinates[x:x+128] for x in range(0, len(grid_coordinates)-128, 128)]
  
  for batch_coordinates in tpu_coords:
  
    image_patches = []
    for corner in batch_coordinates:
      image_patch = get_image_patch(padded_volume, corner)
      image_patches.append(image_patch)

    image_patches = np.array(image_patches)
    image_patches = np.moveaxis(image_patches, 1, 3)  # tf --> (N, W, H, D)
    normalized_image_patches = normalize_batch(image_patches)

    predictions = model.predict(normalized_image_patches)
    
    """
    print(predictions.shape)
    predictions = predictions[0]
    predictions = np.moveaxis(predictions, 2, 0)
    print(predictions.shape)
    plt.figure(figsize=(20, 20))
    plt.imshow(predictions[0]*255, cmap='gray', vmin=0, vmax=255)
    plt.show()
    return
    """

    for i, (d, h, w) in enumerate(batch_coordinates):
      # crop out the patch overlap (remove the perimiter)
      p  = np.moveaxis(predictions[i], 2, 0)[
        OVERLAP_Z : -OVERLAP_Z, 
        OVERLAP_X_Y : -OVERLAP_X_Y,
        OVERLAP_X_Y : -OVERLAP_X_Y
      ]
      # insert that crop into the right place in the result image
      result_volume[
        d + OVERLAP_Z : d + PATCH_SHAPE[0] - OVERLAP_Z, 
        h + OVERLAP_X_Y : h + PATCH_SHAPE[1] - OVERLAP_X_Y, 
        w + OVERLAP_X_Y : w + PATCH_SHAPE[2] - OVERLAP_X_Y
      ] = p
      
  
  # handle the missing patches from TPU divisible batches
  
  missing = len(grid_coordinates) % 128
  extra_batch = grid_coordinates[len(grid_coordinates)-missing:]
  extra_batch += [(0,0,0)] * (128 - missing)

  image_patches = []
  for corner in extra_batch:
    image_patch = get_image_patch(padded_volume, corner)
    image_patches.append(image_patch)

  image_patches = np.array(image_patches)
  image_patches = np.moveaxis(image_patches, 1, 3)  # tf --> (N, W, H, D)
  normalized_image_patches = normalize_batch(image_patches)

  predictions = model.predict(normalized_image_patches)

  for i, (d, h, w) in enumerate(extra_batch[:missing]):
    # crop out the patch overlap (remove the perimiter)
    p  = np.moveaxis(predictions[i], 2, 0)[
      OVERLAP_Z : -OVERLAP_Z, 
      OVERLAP_X_Y : -OVERLAP_X_Y,
      OVERLAP_X_Y : -OVERLAP_X_Y
    ]
    # insert that crop into the right place in the result image
    result_volume[
      d + OVERLAP_Z : d + PATCH_SHAPE[0] - OVERLAP_Z, 
      h + OVERLAP_X_Y : h + PATCH_SHAPE[1] - OVERLAP_X_Y, 
      w + OVERLAP_X_Y : w + PATCH_SHAPE[2] - OVERLAP_X_Y
    ] = p
    
    
  # remove the padding to restore original shape
  result_volume = result_volume[
      OVERLAP_Z + PATCH_SHAPE[0] : -(OVERLAP_Z + PATCH_SHAPE[0]),
      OVERLAP_X_Y + PATCH_SHAPE[1] : -(OVERLAP_X_Y + PATCH_SHAPE[1]),
      OVERLAP_X_Y + PATCH_SHAPE[2] : -(OVERLAP_X_Y + PATCH_SHAPE[2]),
  ]

  return result_volume

#### Run prediction on whole volume (517, 2048, 2048)

In [0]:
t = 'gdrive/My Drive/ROI_1656-6756-329.tiff'
X_test = tifffile.imread(t)
X_test.shape

In [0]:
for xx in range(0, 2048, 512):
  for yy in range(0, 2048, 512):
    start = time.time()
    patch = X_test[:, xx:xx+512, yy:yy+512]
    patch_prediction = run_prediction(tpu_model, patch)
    X_test[:, xx:xx+512, yy:yy+512] = patch_prediction
    print(f'Prediciton {xx},{yy} completed in {int(time.time()-start)} seconds')
    
from os.path import join
save_dir = 'gdrive/My Drive/Unlimited/'
save_path_raw = join(save_dir, 'WHOLE_3VIEW_EXP_L.tiff')
tifffile.imsave(save_path_raw, X_test*255)
!zip -r 'gdrive/My Drive/Unlimited/WHOLE_3VIEW_EXP_L.zip' 'gdrive/My Drive/Unlimited/WHOLE_3VIEW_EXP_L.tiff'

#### Run prediction on individual ROIs 

In [0]:
### PREDICT AND SAVE FOR ALL IN DIRECTORY
from os import listdir, remove, mkdir
from os.path import join, exists
from zipfile import ZipFile

image_dir = 'gdrive/My Drive/scaled-stacks/'
label_dir = 'gdrive/My Drive/scaled-labels-stacks/'
save_dir = 'gdrive/My Drive/predictions_individual/'

if not exists(save_dir):
  mkdir(save_dir)

# replace this with listdir(image_dir) to run on all ROIs
#paths = ['ROI_2052-5784-112.tiff']
#paths = ['roi_test.tif']
#paths = ['ROI_3588-3972-1.tiff']
paths = ['ROI_1656-6756-329.tiff', 'ROI_3624-2712-201.tiff']


for path in paths:
  read_path = join(image_dir, path)
  X_test = tifffile.imread(read_path)
  start = time.time()
  test_prediction = run_prediction(tpu_model, X_test)
  print(f'Prediciton completed in {int(time.time()-start)} seconds')
  
  save_path_raw = join(save_dir, path)
  tifffile.imsave(save_path_raw, test_prediction)
  

(283, 1140, 1140)
Running prediction over 2254 patches
INFO:tensorflow:New input shapes; (re-)compiling: mode=infer (# of cores 8), [TensorSpec(shape=(4, 256, 256, 12), dtype=tf.float32, name='input_1_50')]
INFO:tensorflow:Overriding default placeholder.
INFO:tensorflow:Remapping placeholder for input_1
INFO:tensorflow:Started compiling
INFO:tensorflow:Finished compiling. Time elapsed: 37.53628492355347 secs
INFO:tensorflow:Setting weights on TPU model.
Prediciton completed in 158 seconds
(319, 1140, 1140)
Running prediction over 2548 patches
Prediciton completed in 103 seconds


In [0]:
def overlay_x_on_y(x, y):
  colours = plt.cm.viridis(x)
  colours[..., -1] = x * 1
  fig = plt.figure(figsize=(7, 7), dpi=180)
  ax = fig.add_subplot(111)
  ax.grid(False)
  plt.imshow(y, 'gray', interpolation='bilinear')
  plt.imshow(colours, interpolation='bilinear')
  plt.show()
  

In [0]:
overlay_x_on_y(test_prediction[0, :, :], X_test[0, :, :])
#overlay_x_on_y(test_prediction[160, :, :], X_test[160, :, :])
#overlay_x_on_y(test_prediction[161, :, :], X_test[161, :, :])
#overlay_x_on_y(test_prediction[140, :, :], X_test[140, :, :])
#overlay_x_on_y(test_prediction[210, :, :], X_test[210, :, :])

In [0]:
def iou(ground_truth, prediction):
  intersection = ground_truth * ground_truth
  union = np.sum(ground_truth) + np.sum(prediction)
  score = intersection.sum() / union
  return score

for path in paths:
  read_path = join(label_dir, path)
  y_true = np.where(tifffile.imread(read_path) >= 0.5, 1, 0)
  print(iou(y_true, test_prediction))


0.2086840829512997
0.303308746522341


#### Connected component analysis for the removal of noise is best accomplished using the Fiji plugin called MorphoLibJ. 