# **3D U-Net**

<font size = 4> The 3D U-Net was first introduced by [Çiçek et al.](https://arxiv.org/abs/1606.06650) for learning dense volumetric segmentations from sparsely annotated ground-truth data, building upon the original U-Net architecture by [Ronneberger et al.](https://arxiv.org/abs/1505.04597).

---

<font size = 4>*Disclaimer*:

<font size = 4>This notebook is inspired by the *Zero-Cost Deep-Learning to Enhance Microscopy* project ([ZeroCostDL4Mic](https://github.com/HenriquesLab/ZeroCostDL4Mic/wiki)) and was created by Daniel Krentzel. The source code for this implementation can be found [here](https://github.com/krentzd/unet-3d).

<font size = 4>This notebook is largely based on the following paper:

<font size = 4>[**3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation**](https://arxiv.org/pdf/1606.06650.pdf), Çiçek Ö, Abdulkadir A, Lienkamp SS, Brox T, Ronneberger O. International conference on medical image computing and computer-assisted intervention 2016 Oct 17 (pp. 424-432). Springer, Cham.

<font size = 4>The following two Python libraries play an important role in the notebook: 

1. <font size = 4>[**Elasticdeform**](https://github.com/gvtulder/elasticdeform)
 by Gijs van Tulder was used to augment the 3D training data using elastic grid-based deformations as described in the original 3D U-Net paper. 

2. <font size = 4>[**Tifffile**](https://github.com/cgohlke/tifffile) by Christoph Gohlke is a great library for reading and writing TIFF files. 

<font size = 4>The [example dataset](https://www.epfl.ch/labs/cvlab/data/data-em/) represents a 5x5x5µm section taken from the CA1 hippocampus region of the brain with annotated mitochondria and was acquired by Graham Knott and Marco Cantoni at EPFL.


<font size = 4>**Please also cite the original paper and relevant Python libraries when using or developing this notebook.**

# **How to use this notebook?**

---

<font size = 4>Videos describing how to use ZeroCostDL4Mic notebooks are available on YouTube:
  - [**Video 1**](https://www.youtube.com/watch?v=GzD2gamVNHI&feature=youtu.be): Full run through of the workflow to obtain the notebooks and the provided test datasets as well as a common use of the notebook
  - [**Video 2**](https://www.youtube.com/watch?v=PUuQfP5SsqM&feature=youtu.be): Detailed description of the different sections of the notebook


---
### **Structure of a notebook**

<font size = 4>The notebook contains two types of cells:  

<font size = 4>**Text cells** provide information and can be modified by double-clicking the cell. You are currently reading a text cell. You can create a new one by clicking `+ Text`.

<font size = 4>**Code cells** contain code which can be modified by selecting the cell. To execute the cell, move your cursor to the `[]`-symbol on the left side of the cell (a play button should appear). Click it to execute the cell. Once the cell is fully executed, the animation stops. You can create a new code cell by clicking `+ Code`.

---
### **Table of contents, Code snippets** and **Files**

<font size = 4>Three tabs are located on the upper left side of the notebook:

1. <font size = 4>*Table of contents* contains the structure of the notebook. Click the headers to move quickly between sections.

2. <font size = 4>*Code snippets* provides a wide array of example code specific to Google Colab. You can ignore this when using this notebook.

3. <font size = 4>*Files* displays the current working directory. We will mount your Google Drive in Section 1.2. so that you can access your files and save them permanently.

<font size = 4>**Important:** All uploaded files are purged once the runtime ends.

<font size = 4>**Note:** The directory *sample data* in *Files* contains default files. Do not upload anything there!

---
### **Making changes to the notebook**

<font size = 4>**You can make a copy** of the notebook and save it to your Google Drive by clicking *File* -> *Save a copy in Drive*.

<font size = 4>To **edit a cell**, double click on the text. This will either display the source code (in code cells) or the [markdown](https://colab.research.google.com/notebooks/markdown_guide.ipynb#scrollTo=70pYkR9LiOV0) (in text cells).
You can use `#` in code cells to comment out parts of the code. This allows you to keep the original piece of code while not executing it.

# **0. Before getting started**
---

<font size = 4>As the network operates in three dimensions, some care should be taken to pre-process the data correctly. Ensure that the structure of interest does not change substantially between slices; image volumes with isotropic pixel sizes are ideal for this architecture.
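
<font size = 4>As a rough illustration of the isotropy check, the sketch below computes the factor by which the z-axis would have to be resampled to match the lateral pixel size. The function name `z_zoom_factor` and the voxel sizes are hypothetical and not part of this notebook; the resulting factor could, for instance, be passed to `scipy.ndimage.zoom`, which is imported in Section 2.

```python
def z_zoom_factor(z_step, xy_pixel_size):
    """Return the zoom factor along z that would make voxels isotropic.

    Both arguments are physical sizes in the same unit (e.g. nm).
    A value of 1.0 means the volume is already isotropic.
    """
    if z_step <= 0 or xy_pixel_size <= 0:
        raise ValueError('Voxel sizes must be positive')
    return z_step / xy_pixel_size


# Hypothetical acquisition: 5 nm lateral pixels, 20 nm between slices
factor = z_zoom_factor(z_step=20.0, xy_pixel_size=5.0)
print(factor)  # 4.0 -> four times as many slices after resampling
```

<font size = 4>A factor close to 1 suggests the volume can be used as is; a much larger factor indicates strongly anisotropic data.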

<font size = 4>Each image volume must be provided as a **multipage TIFF file** to maintain the correct ordering of individual image slices. If more than one image volume has been annotated, source and target files must be named identically and placed in separate directories. If only one image volume has been annotated, the source and target files do not have to be placed in separate directories and can be named differently, as long as their paths are explicitly provided in Section 3.

<font size = 4>**Prepare two datasets** (*training* and *testing*) for quality control purposes. Make sure that the *testing* dataset does not overlap with the *training* dataset and is ideally sourced from a different acquisition and sample to ensure robustness of the trained model.


---


### **Directory structure**

<font size = 4>Make sure to adhere to one of the following directory structures. If only one annotated training volume exists, choose the first structure. In case more than one training volume is available, choose the second structure.

<font size = 4>**Structure 1:** Only one training volume
```
path/to/directory/with/one/training/volume
│--training_source.tif
│--training_target.tif
│
│--testing_source.tif
│--testing_target.tif
│
│--data_to_predict_on.tif
│--prediction_results.tif

```
<font size = 4>**Structure 2:** Various training volumes
```
path/to/directory/with/various/training/volumes
│--testing_source.tif
│--testing_target.tif
│
└───training
│   └───source
│   │   │--training_volume_one.tif
│   │   │--training_volume_two.tif
│   │   │--...
│   │   │--training_volume_n.tif
│   │
│   └───target
│       │--training_volume_one.tif
│       │--training_volume_two.tif
│       │--...
│       │--training_volume_n.tif
│
│--data_to_predict_on.tif
│--prediction_results.tif
```
<font size = 4>**Note:** Naming directories is completely up to you, as long as the paths are correctly specified throughout the notebook.
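
<font size = 4>Because the data generator in Section 2 sorts both directory listings and pairs source and target files by position, a missing or misnamed file can silently pair the wrong volumes. The sketch below is one possible check for Structure 2; the helper name `unpaired_volumes` and the example paths are hypothetical and not part of this notebook.

```python
import os


def unpaired_volumes(source_dir, target_dir):
    """Return the file names that appear in only one of the two directories."""
    source_names = set(os.listdir(source_dir))
    target_names = set(os.listdir(target_dir))
    # Symmetric difference: names missing from either side
    return sorted(source_names ^ target_names)


# Example call (placeholder paths, adjust to your own directory structure):
# missing = unpaired_volumes('training/source', 'training/target')
# if missing:
#     print('Files without a counterpart:', missing)
```

<font size = 4>An empty list means every source volume has an identically named target volume.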


---


### **Important note**

* <font size = 4>If you wish to **Train a network from scratch** using your own dataset (and we encourage everyone to do so), you will need to run **Sections 1 - 4**, then use **Section 5** to assess the quality of your model and **Section 6** to run predictions using the model that you trained.

* <font size = 4>If you wish to **Evaluate your model** using a model previously generated and saved on your Google Drive, you will only need to run **Sections 1 and 2** to set up the notebook, then use **Section 5** to assess the quality of your model.

* <font size = 4> If you only wish to **Run predictions** using a model previously generated and saved on your Google Drive, you will only need to run **Sections 1 and 2** to set up the notebook, then use **Section 6** to run the predictions on the desired model.
---

In [None]:
#@markdown ##**Download example dataset**

#@markdown <font size = 4> This usually takes a few minutes. The images are saved in *example_dataset*.

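import math
import os
import requests
from tqdm.notebook import tqdm

def make_directory(path):
    # Create the directory (and any missing parents) if it does not exist yet
    os.makedirs(path, exist_ok=True)

def download_from_url(url, save_as, chunk_size=1024):
    r = requests.get(url, stream=True)
    # Stop on HTTP errors instead of saving an error page as a TIFF file
    r.raise_for_status()

    # Size the progress bar from the Content-Length header when the server provides it
    content_length = r.headers.get('Content-Length')
    total = math.ceil(int(content_length) / chunk_size) if content_length else None

    with open(save_as, 'wb') as file:
        for block in tqdm(r.iter_content(chunk_size=chunk_size), desc='Downloading ' + os.path.basename(save_as), total=total, ncols=1000):
            if block:
                file.write(block)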


make_directory('example_dataset')

download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/training.tif', 'example_dataset/training.tif')
download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/training_groundtruth.tif', 'example_dataset/training_groundtruth.tif')
download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/testing.tif', 'example_dataset/testing.tif')
download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/testing_groundtruth.tif', 'example_dataset/testing_groundtruth.tif')

print('Example dataset successfully downloaded!')

# **1. Initialise the Colab session**
---







## **1.1. Check GPU access and Python version**
---

<font size = 4>By default, Colab sessions run Python 3 with GPU acceleration. You can manually set this by:

1. <font size = 4>Going to **Runtime -> Change runtime type**

2. <font size = 4>**Runtime type: Python 3** *(This notebook uses Python 3)*

3. <font size = 4>**Hardware accelerator: GPU** *(Graphics Processing Unit)*


In [None]:
#@markdown ##Run this cell to check if you have GPU access
%tensorflow_version 1.x

import tensorflow as tf
if tf.test.gpu_device_name()=='':
  print('You do not have GPU access.') 
  print('Did you change your runtime ?') 
  print('If the runtime setting is correct then Google did not allocate a GPU for your session')
  print('Expect slow performance. To access GPU try reconnecting later')

else:
  print('You have GPU access')
  !nvidia-smi


## **1.2. Mount Google Drive**
---
<font size = 4> To use this notebook with your **own data**, place it in a folder on **Google Drive** following one of the directory structures outlined in **Section 0**.

1. <font size = 4> **Run** the **cell** below to mount your Google Drive and follow the link. 

2. <font size = 4>**Sign in** to your Google account and press 'Allow'. 

3. <font size = 4>Next, copy the **authorization code**, paste it into the cell and press enter. This will allow Colab to read and write data from and to your Google Drive. 

4. <font size = 4> Once this is done, your data can be viewed in the **Files tab** on the top left of the notebook after hitting 'Refresh'.

In [None]:
#@markdown ##Run this cell to connect your Google Drive to Colab

from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
#@markdown ##Unzip pre-trained model directory

#@markdown 1.  <font size = 4>Upload a zipped model directory using the *Files* tab
#@markdown 2.  <font size = 4>Run this cell to unzip your model file
#@markdown 3.  <font size = 4>The model directory will appear in the *Files* tab 

from google.colab import files

zipped_model_file = "" #@param {type:"string"}

# Quote the path so that file names containing spaces (e.g. "My Drive") are handled correctly
!unzip "$zipped_model_file"

# **2. Install 3D U-Net dependencies**
---


In [None]:
#@markdown ##Install dependencies and instantiate network

#Put the imported code and libraries here
from __future__ import absolute_import, division, print_function, unicode_literals

try:
    import elasticdeform
except ImportError:
    !pip install elasticdeform
    import elasticdeform

try:
    import tifffile
except ImportError:
    !pip install tifffile
    import tifffile

import os
import csv
import random
import h5py
import imageio
import math
import shutil

import pandas as pd
from glob import glob
from tqdm.notebook import tqdm

from skimage import transform
from skimage import exposure
from skimage import color

from scipy.ndimage import zoom

import matplotlib.pyplot as plt

import numpy as np
import tensorflow as tf

from keras import backend as K

from keras.layers import Conv3D
from keras.layers import BatchNormalization
from keras.layers import ReLU
from keras.layers import MaxPooling3D
from keras.layers import Conv3DTranspose
from keras.layers import Input
from keras.layers import Concatenate
from keras.models import Model
from keras.utils import Sequence
from keras.callbacks import ModelCheckpoint
from keras.callbacks import CSVLogger
from keras.callbacks import Callback

from ipywidgets import interact
from ipywidgets import interactive
from ipywidgets import fixed
from ipywidgets import interact_manual 
import ipywidgets as widgets

print("Dependencies installed and imported.")

# Define MultiPageTiffGenerator class
class MultiPageTiffGenerator(Sequence):

    def __init__(self,
                 source_path,
                 target_path,
                 batch_size=1,
                 shape=(128,128,32,1),
                 augment=True,
                 val_split=0.2,
                 is_val=False,
                 random_crop=True,
                 downscale=1):

        # A single multi-page TIFF file is read directly; a directory of multi-page TIFF files is read as a sorted list
        if os.path.isfile(source_path):
            self.dir_flag = False
            self.source = tifffile.imread(source_path)
            self.target = tifffile.imread(target_path).astype(bool)

        elif os.path.isdir(source_path):
            self.dir_flag = True
            self.source_dir_list = glob(os.path.join(source_path, '*'))
            self.target_dir_list = glob(os.path.join(target_path, '*'))

            self.source_dir_list.sort()
            self.target_dir_list.sort()

        self.shape = shape
        self.batch_size = batch_size
        self.augment = augment
        self.val_split = val_split
        self.is_val = is_val
        self.random_crop = random_crop
        self.downscale = downscale

        self.on_epoch_end()

    def __len__(self):
        # If several multi-page TIFF files are provided, sum the images within each
        # Expected number of non-augmented images is 1/3 of entire training set, hence multiply length by 3
        if self.augment:
            augment_factor = 3
        else:
            augment_factor = 1
    
        if self.dir_flag:
            num_of_imgs = 0
            for tiff_path in self.source_dir_list:
                num_of_imgs += tifffile.imread(tiff_path).shape[0]
            xy_shape = tifffile.imread(self.source_dir_list[0]).shape[1:]

            if self.is_val:
                if self.random_crop:
                    crop_volume = self.shape[0] * self.shape[1] * self.shape[2]
                    volume = xy_shape[0] * xy_shape[1] * self.val_split * num_of_imgs
                    # Number of patches that fit into the data volume (same ratio as the single-file branch)
                    return math.floor(augment_factor * volume / (crop_volume * self.batch_size))
                else:
                    return math.floor(self.val_split * num_of_imgs / self.batch_size)
            else:
                if self.random_crop:
                    crop_volume = self.shape[0] * self.shape[1] * self.shape[2]
                    volume = xy_shape[0] * xy_shape[1] * (1 - self.val_split) * num_of_imgs
                    return math.floor(augment_factor * volume / (crop_volume * self.batch_size))

                else:
                    return math.floor(augment_factor*(1 - self.val_split) * num_of_imgs/self.batch_size)
        else:
            if self.is_val:
                if self.random_crop:
                    crop_volume = self.shape[0] * self.shape[1] * self.shape[2]
                    volume = self.source.shape[0] * self.source.shape[1] * self.val_split * self.source.shape[2]
                    return math.floor(augment_factor * volume / (crop_volume * self.batch_size))
                else:
                    return math.floor((self.val_split * self.source.shape[0] / self.batch_size))
            else:
                if self.random_crop:
                    crop_volume = self.shape[0] * self.shape[1] * self.shape[2]
                    volume = self.source.shape[0] * self.source.shape[1] * (1 - self.val_split) * self.source.shape[2]
                    return math.floor(augment_factor * volume / (crop_volume * self.batch_size))
                else:
                    return math.floor(augment_factor * (1 - self.val_split) * self.source.shape[0] / self.batch_size)

    def __getitem__(self, idx):

        source_batch = np.empty((self.batch_size,
                                 self.shape[0],
                                 self.shape[1],
                                 self.shape[2],
                                 self.shape[3]))
        target_batch = np.empty((self.batch_size,
                                 self.shape[0],
                                 self.shape[1],
                                 self.shape[2],
                                 self.shape[3]))

        for batch in range(self.batch_size):
            # Modulo operator ensures IndexError is avoided
            stack_start = self.batch_list[(idx+batch*self.shape[2])%len(self.batch_list)]

            if self.dir_flag:
                self.source = tifffile.imread(self.source_dir_list[stack_start[0]])
                self.target = tifffile.imread(self.target_dir_list[stack_start[0]]).astype(bool)

            src_list = []
            tgt_list = []
            for i in range(stack_start[1], stack_start[1]+self.shape[2]):
                src = self.source[i]
                src = transform.downscale_local_mean(src, (self.downscale, self.downscale))
                if not self.random_crop:
                    src = transform.resize(src, (self.shape[0], self.shape[1]), mode='constant', preserve_range=True)
                src = src/255
                src_list.append(src)

                tgt = self.target[i]
                tgt = transform.downscale_local_mean(tgt, (self.downscale, self.downscale))
                if not self.random_crop:
                    tgt = transform.resize(tgt, (self.shape[0], self.shape[1]), mode='constant', preserve_range=True)
                tgt_list.append(tgt)

            if self.random_crop:
                if src.shape[0] == self.shape[0]:
                    x_rand = 0
                if src.shape[1] == self.shape[1]:
                    y_rand = 0
                if src.shape[0] > self.shape[0]:
                    x_rand = np.random.randint(src.shape[0] - self.shape[0])
                if src.shape[1] > self.shape[1]:
                    y_rand = np.random.randint(src.shape[1] - self.shape[1])
                if src.shape[0] < self.shape[0] or src.shape[1] < self.shape[1]:
                    raise ValueError('Patch shape larger than (downscaled) source shape')
            
            for i in range(self.shape[2]):
                if self.random_crop:
                    src = src_list[i]
                    tgt = tgt_list[i]
                    src_crop = src[x_rand:self.shape[0]+x_rand, y_rand:self.shape[1]+y_rand]
                    tgt_crop = tgt[x_rand:self.shape[0]+x_rand, y_rand:self.shape[1]+y_rand]
                else:
                    src_crop = src_list[i]
                    tgt_crop = tgt_list[i]

                source_batch[batch,:,:,i,0] = src_crop
                target_batch[batch,:,:,i,0] = tgt_crop

        if self.augment:
            # On-the-fly data augmentation
            rand = np.random.random()
            # Data augmentation by reversing the stack along the z-axis (axis 3 of the batch)
            if rand > 2/3:
                source_batch_rev = source_batch[:, :, :, ::-1]
                target_batch_rev = target_batch[:, :, :, ::-1]

                return source_batch_rev, target_batch_rev
            # Data augmentation by elastic deformation
            elif rand < 2/3 and rand > 1/3:
                [source_batch_deform, target_batch_deform] = elasticdeform.deform_random_grid([source_batch, target_batch],
                                                                                              axis=(1, 2, 3),
                                                                                              sigma=5,
                                                                                              points=3,
                                                                                              order=4) # points=2 is better imo
                target_batch_deform_bin = target_batch_deform > 0.25

                return source_batch_deform, target_batch_deform_bin
            else:
                return source_batch, target_batch
        else:
            return source_batch, target_batch

    def on_epoch_end(self):
        # Validation split performed here
        self.batch_list = []
        # Create batch_list of all combinations of tifffile and stack position
        if self.dir_flag:
            for i in range(len(self.source_dir_list)):
                num_of_pages = tifffile.imread(self.source_dir_list[i]).shape[0]
                if self.is_val:
                    start_page = num_of_pages-math.floor(self.val_split*num_of_pages)
                    for j in range(start_page, num_of_pages-self.shape[2]):
                      self.batch_list.append([i, j])
                else:
                    last_page = math.floor((1-self.val_split)*num_of_pages)
                    for j in range(last_page-self.shape[2]):
                        self.batch_list.append([i, j])
        else:
            num_of_pages = self.source.shape[0]
            if self.is_val:
                start_page = num_of_pages-math.floor(self.val_split*num_of_pages)
                for j in range(start_page, num_of_pages-self.shape[2]):
                    self.batch_list.append([0, j])

            else:
                last_page = math.floor((1-self.val_split)*num_of_pages)
                for j in range(last_page-self.shape[2]):
                    self.batch_list.append([0, j])
        
        if self.is_val and (len(self.batch_list) <= 0):
            raise ValueError('validation_split too small! Increase val_split or decrease z-depth')
        random.shuffle(self.batch_list)

    def class_weights(self):

        ones = 0
        pixels = 0

        if self.dir_flag:
            for i in range(len(self.target_dir_list)):
                tgt = tifffile.imread(self.target_dir_list[i]).astype(bool)
                ones += np.sum(tgt)
                pixels += tgt.shape[0]*tgt.shape[1]*tgt.shape[2]
        else:
            ones = np.sum(self.target)
            pixels = self.target.shape[0]*self.target.shape[1]*self.target.shape[2]
        p_ones = ones/pixels
        p_zeros = 1-p_ones

        # Return swapped probability to increase weight of unlikely class
        return p_ones, p_zeros

# Define custom loss and dice coefficient
def dice_coefficient(y_true, y_pred):

    eps = 1e-6
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f*y_pred_f)

    return (2.*intersection)/(K.sum(y_true_f*y_true_f)+K.sum(y_pred_f*y_pred_f)+eps)

def weighted_binary_crossentropy(zero_weight, one_weight):

    def _weighted_binary_crossentropy(y_true, y_pred):

        binary_crossentropy = K.binary_crossentropy(y_true, y_pred)

        weight_vector = y_true*one_weight+(1.-y_true)*zero_weight
        weighted_binary_crossentropy = weight_vector*binary_crossentropy

        return K.mean(weighted_binary_crossentropy)

    return _weighted_binary_crossentropy

# Custom callback showing sample prediction
class SampleImageCallback(Callback):

    def __init__(self, model, sample_data, model_path, save=False):
        self.model = model
        self.sample_data = sample_data
        self.model_path = model_path
        self.save = save

    def on_epoch_end(self, epoch, logs={}):

        sample_predict = self.model.predict_on_batch(self.sample_data)

        f=plt.figure(figsize=(16,8))
        plt.subplot(1,2,1)
        plt.imshow(self.sample_data[0,:,:,0,0], interpolation='nearest', cmap='gray')
        plt.title('Sample source')
        plt.axis('off');

        plt.subplot(1,2,2)
        plt.imshow(sample_predict[0,:,:,0,0], interpolation='nearest', cmap='magma')
        plt.title('Predicted target')
        plt.axis('off');

        # Save before plt.show(), which clears the current figure
        if self.save:
            plt.savefig(self.model_path + '/epoch_' + str(epoch+1) + '.png')

        plt.show()


# Define Unet3D class
class Unet3D:

    def __init__(self,
                 shape=(256,256,16,1)):

        if isinstance(shape, str):
            shape = eval(shape)

        self.shape = shape
        
        input_tensor = Input(self.shape, name='input')

        self.model = self.unet_3D(input_tensor)

    def down_block_3D(self, input_tensor, filters):

        x = Conv3D(filters=filters, kernel_size=(3,3,3), padding='same')(input_tensor)
        x = BatchNormalization()(x)
        x = ReLU()(x)

        x = Conv3D(filters=filters*2, kernel_size=(3,3,3), padding='same')(x)
        x = BatchNormalization()(x)
        x = ReLU()(x)

        return x

    def up_block_3D(self, input_tensor, concat_layer, filters):

        x = Conv3DTranspose(filters, kernel_size=(2,2,2), strides=(2,2,2))(input_tensor)

        x = Concatenate()([x, concat_layer])

        x = Conv3D(filters=filters, kernel_size=(3,3,3), padding='same')(x)
        x = BatchNormalization()(x)
        x = ReLU()(x)

        x = Conv3D(filters=filters*2, kernel_size=(3,3,3), padding='same')(x)
        x = BatchNormalization()(x)
        x = ReLU()(x)

        return x

    def unet_3D(self, input_tensor, filters=32):

        d1 = self.down_block_3D(input_tensor, filters=filters)
        p1 = MaxPooling3D(pool_size=(2,2,2), strides=(2,2,2), data_format='channels_last')(d1)
        d2 = self.down_block_3D(p1, filters=filters*2)
        p2 = MaxPooling3D(pool_size=(2,2,2), strides=(2,2,2), data_format='channels_last')(d2)
        d3 = self.down_block_3D(p2, filters=filters*4)
        p3 = MaxPooling3D(pool_size=(2,2,2), strides=(2,2,2), data_format='channels_last')(d3)

        d4 = self.down_block_3D(p3, filters=filters*8)

        u1 = self.up_block_3D(d4, d3, filters=filters*4)
        u2 = self.up_block_3D(u1, d2, filters=filters*2)
        u3 = self.up_block_3D(u2, d1, filters=filters)

        output_tensor = Conv3D(filters=1, kernel_size=(1,1,1), activation='sigmoid')(u3)

        return Model(inputs=[input_tensor], outputs=[output_tensor])

    def summary(self):
        return self.model.summary()

    def train(self, 
              epochs, 
              batch_size, 
              train_source, 
              train_target, 
              model_path, 
              model_name, 
              val_split=0.2, 
              augment=True, 
              ckpt_period=1, 
              save_best_ckpt_only=False, 
              ckpt_path=None,
              random_crop=True,
              downscaling=1):

        train_generator = MultiPageTiffGenerator(train_source,
                                                 train_target,
                                                 batch_size=batch_size,
                                                 shape=self.shape,
                                                 augment=augment,
                                                 val_split=val_split,
                                                 random_crop=random_crop,
                                                 downscale=downscaling)

        val_generator = MultiPageTiffGenerator(train_source,
                                               train_target,
                                               batch_size=batch_size,
                                               shape=self.shape,
                                               augment=False,
                                               val_split=val_split,
                                               is_val=True,
                                               random_crop=random_crop,
                                               downscale=downscaling)

        class_weight_zero, class_weight_one = train_generator.class_weights()

        self.model.compile(optimizer='adam',
                           loss=weighted_binary_crossentropy(class_weight_zero, class_weight_one),
                           metrics=[dice_coefficient])

        if ckpt_path is not None:
            self.model.load_weights(ckpt_path)

        full_model_path = os.path.join(model_path, model_name)

        if not os.path.exists(full_model_path):
            os.makedirs(full_model_path)
        
        log_dir = full_model_path + '/Quality Control'

        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        
        ckpt_dir =  full_model_path + '/ckpt'

        if not os.path.exists(ckpt_dir):
            os.makedirs(ckpt_dir)

        csv_out_name = log_dir + '/training_evaluation.csv'
        if ckpt_path is None:
            csv_logger = CSVLogger(csv_out_name)
        else:
            csv_logger = CSVLogger(csv_out_name, append=True)

        if save_best_ckpt_only:
            ckpt_name = ckpt_dir + '/' + model_name + '.hdf5'
        else:
            ckpt_name = ckpt_dir + '/' + model_name + '_epoch_{epoch:02d}_val_loss_{val_loss:.4f}.hdf5'
        
        model_ckpt = ModelCheckpoint(ckpt_name,
                                     verbose=1,
                                     period=ckpt_period,
                                     save_best_only=save_best_ckpt_only,
                                     save_weights_only=True)

        sample_batch = val_generator[0][0]
        sample_img = SampleImageCallback(self.model, 
                                         sample_batch, 
                                         model_path)

        self.model.fit_generator(generator=train_generator,
                                 validation_data=val_generator,
                                 validation_steps=math.floor(len(val_generator)/batch_size),
                                 epochs=epochs,
                                 callbacks=[csv_logger,
                                            model_ckpt,
                                            sample_img])

        last_ckpt_name = ckpt_dir + '/' + model_name + '_last.hdf5'
        self.model.save_weights(last_ckpt_name)

    def predict(self, input, ckpt_path, z_range=None, downscaling=None, true_patch_size=None):

        self.model.load_weights(ckpt_path)

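        if isinstance(downscaling, str):
            downscaling = eval(downscaling)

        # NaN means "no downscaling"; the default None must not be passed to math.isnan
        if isinstance(downscaling, float) and math.isnan(downscaling):
            downscaling = None

        if isinstance(true_patch_size, str):
            true_patch_size = eval(true_patch_size)

        # NaN means "no patch size given"; tuples and None are left untouched
        if isinstance(true_patch_size, float) and math.isnan(true_patch_size):
            true_patch_size = None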

        if isinstance(input, str):
            src_volume = tifffile.imread(input)
        elif isinstance(input, np.ndarray):
            src_volume = input
        else:
            raise TypeError('Input is not path or numpy array!')
        
        in_size = src_volume.shape

        if downscaling is not None or true_patch_size is not None:
            x_scaling = 0
            y_scaling = 0

            if true_patch_size is not None:
                x_scaling += true_patch_size[0]/self.shape[0]
                y_scaling += true_patch_size[1]/self.shape[1]
            if downscaling is not None:
                x_scaling += downscaling
                y_scaling += downscaling

            src_list = []
            for i in range(src_volume.shape[0]):
                 src_list.append(transform.downscale_local_mean(src_volume[i], (int(x_scaling), int(y_scaling))))
            src_volume = np.array(src_list)          

        if z_range is not None:
            src_volume = src_volume[z_range[0]:z_range[1]]

        src_volume = src_volume/255
       
        src_array = np.zeros((1,
                              math.ceil(src_volume.shape[1]/self.shape[0])*self.shape[0], 
                              math.ceil(src_volume.shape[2]/self.shape[1])*self.shape[1],
                              math.ceil(src_volume.shape[0]/self.shape[2])*self.shape[2], 
                              self.shape[3]))

        for i in range(src_volume.shape[0]):
            src_array[0,:src_volume.shape[1],:src_volume.shape[2],i,0] = src_volume[i]

        pred_array = np.empty(src_array.shape)

        for i in range(math.ceil(src_volume.shape[1]/self.shape[0])):
          for j in range(math.ceil(src_volume.shape[2]/self.shape[1])):
            for k in range(math.ceil(src_volume.shape[0]/self.shape[2])):
                pred_temp = self.model.predict(src_array[:,
                                                         i*self.shape[0]:i*self.shape[0]+self.shape[0],
                                                         j*self.shape[1]:j*self.shape[1]+self.shape[1],
                                                         k*self.shape[2]:k*self.shape[2]+self.shape[2]])
                pred_array[:,
                           i*self.shape[0]:i*self.shape[0]+self.shape[0],
                           j*self.shape[1]:j*self.shape[1]+self.shape[1],
                           k*self.shape[2]:k*self.shape[2]+self.shape[2]] = pred_temp
                           
        pred_volume = np.rollaxis(np.squeeze(pred_array), -1)[:src_volume.shape[0],:src_volume.shape[1],:src_volume.shape[2]]            

        if downscaling is not None:
            pred_list = []
            for i in range(pred_volume.shape[0]):
                 pred_list.append(transform.resize(pred_volume[i], (in_size[1], in_size[2]), preserve_range=True))
            pred_volume = np.array(pred_list)

        return pred_volume



# **3. Select your model parameters**

---


## **Paths to training data and model**

* <font size = 4>**`training_source`** and **`training_target`** specify the paths to the training data. Each can either be a single multipage TIFF file or a directory containing multiple multipage TIFF files, in which case source and target files must be named identically within their respective directories. See Section 0 for a detailed description of the necessary directory structure.

* <font size = 4>**`model_name`** will be used when naming checkpoints. Adhere to a `lower_case_with_underscores` naming convention and beware of using the name of an existing model within the same folder, as it will be overwritten.

* <font size = 4>**`model_path`** specifies the directory where the model checkpoints and quality control logs will be saved.


<font size = 4>**Note:** You can copy paths from the 'Files' tab by right-clicking any folder or file and selecting 'Copy path'. 
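
<font size = 4>When directories are provided, it can be useful to confirm beforehand that every source file has an identically named target file. The following is a minimal sketch using only the Python standard library; the function `pair_training_files` and the placeholder file names are hypothetical and are not part of the notebook's own loading code.

```python
import os
import tempfile


def pair_training_files(source_dir, target_dir):
    """Pair source and target files that share the same file name."""
    source_names = sorted(os.listdir(source_dir))
    target_names = sorted(os.listdir(target_dir))

    # Names present in only one of the two directories cannot be paired
    unmatched = set(source_names) ^ set(target_names)
    if unmatched:
        raise ValueError('Files without a counterpart: ' + ', '.join(sorted(unmatched)))

    return [(os.path.join(source_dir, name), os.path.join(target_dir, name))
            for name in source_names]


# Demonstration on empty placeholder files in a temporary directory
with tempfile.TemporaryDirectory() as root:
    for folder in ('source', 'target'):
        os.makedirs(os.path.join(root, folder))
        for name in ('vol_a.tif', 'vol_b.tif'):
            open(os.path.join(root, folder, name), 'wb').close()

    pairs = pair_training_files(os.path.join(root, 'source'),
                                os.path.join(root, 'target'))
    paired_names = [(os.path.basename(s), os.path.basename(t)) for s, t in pairs]

print(paired_names)
```

<font size = 4>A `ValueError` listing the unmatched names is raised if the two directories do not contain the same set of file names.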

## **Training parameters**

* <font size = 4>**`number_of_epochs`** is the number of times the entire training data will be seen by the model. *Default: >100*

* <font size = 4>**`batch_size`** is the number of training patches of size `patch_size` that will be bundled together at each training step. *Default: 1*

* <font size = 4>**`patch_size`** specifies the size of the three-dimensional training patches that will be fed to the model. In order to avoid errors, preferably use a square aspect ratio or stick to the advanced parameters. *Default: <(512, 512, 16)*

* <font size = 4>**`validation_split_in_percent`** is the relative amount of training data that will be set aside for validation. *Default: 20* 

* <font size = 4>**`downscaling_in_xy`** downscales the training images by the specified amount in x and y. This is useful to enforce isotropic pixel-size if the z resolution is lower than the xy resolution in the training volume or to capture a larger field-of-view while decreasing the memory requirements. *Default: 1*

* <font size = 4>**`image_pre_processing`** selects whether the training images are randomly cropped during training or resized to `patch_size`. Choose `randomly crop to patch_size` to shrink the field-of-view of the training images to the `patch_size`. *Default: resize to patch_size* 

<font size = 4>**Note:** If a *ResourceExhaustedError* is raised in Section 4.1. during training, decrease `batch_size` and `patch_size`. Decrease `batch_size` first and if the error persists at `batch_size = 1`, reduce the `patch_size`.  
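
<font size = 4>The order of reductions described in the note can be sketched as follows. This is only an illustration: the helper names are hypothetical, and halving the batch size and the xy extent of the patch at each step is an assumption, not a rule used by the notebook. The number of voxels per training step is printed as a rough proxy for how the input size shrinks.

```python
def reduce_memory_settings(batch_size, patch_size):
    """Return the next, smaller (batch_size, patch_size) pair to try."""
    if batch_size > 1:
        # Reduce the batch size first (assumption: halve it)
        return batch_size // 2, patch_size
    # Only once batch_size is 1, shrink the patch (assumption: halve x and y)
    x, y, z = patch_size
    return 1, (max(x // 2, 1), max(y // 2, 1), z)


def voxels_per_step(batch_size, patch_size):
    """Number of input voxels fed to the model in one training step."""
    x, y, z = patch_size
    return batch_size * x * y * z


settings = (4, (512, 512, 16))
for _ in range(4):
    print(settings, voxels_per_step(*settings))
    settings = reduce_memory_settings(*settings)
```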


## **Data augmentation**
 
* <font size = 4>**`apply_data_augmentation`** ensures that data augmentation is randomly applied to the training data at each training step. This includes inverting the order of the slices within a training patch, as well as applying elastic grid-based deformations as described in the original 3D U-Net paper. Augmenting the training data increases the robustness of the model by simulating possible variations within the training data, which helps prevent overfitting on small datasets. We therefore strongly recommend enabling data augmentation. *Default: True*

<font size = 4>**Note:** The number of steps per epoch is calculated as `floor(augment_factor * (1 - validation_split) * num_of_slices / batch_size)` if `image_pre_processing` is `resize to patch_size`, where `augment_factor` is three if `apply_data_augmentation` is `True` and one otherwise. The `num_of_slices` is the overall number of slices (z-depth) in the training set across all provided image volumes. If `image_pre_processing` is `randomly crop to patch_size`, the number of steps per epoch is calculated as `floor(augment_factor * volume / (crop_volume * batch_size))`, where `volume` is the overall volume of the training data in pixels accounting for the validation split and `crop_volume` is the volume in pixels based on the specified `patch_size`.
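
<font size = 4>The two formulas in the note can be written out as a short sketch. The function names and the example numbers are hypothetical; only the formulas themselves are taken from the note above.

```python
import math


def steps_per_epoch_resize(num_of_slices, batch_size, validation_split, augment):
    """Steps per epoch when image_pre_processing is 'resize to patch_size'."""
    augment_factor = 3 if augment else 1
    return math.floor(augment_factor * (1 - validation_split) * num_of_slices / batch_size)


def steps_per_epoch_crop(volume, patch_size, batch_size, augment):
    """Steps per epoch when image_pre_processing is 'randomly crop to patch_size'.

    volume is the number of training pixels after the validation split.
    """
    augment_factor = 3 if augment else 1
    crop_volume = patch_size[0] * patch_size[1] * patch_size[2]
    return math.floor(augment_factor * volume / (crop_volume * batch_size))


# Example values (hypothetical): 100 slices, 20 % validation split, batch size 1
print(steps_per_epoch_resize(100, 1, 0.2, augment=True))
print(steps_per_epoch_resize(100, 1, 0.2, augment=False))

# Example values (hypothetical): 768 x 1024 x 132 training pixels, (256, 256, 8) patches
print(steps_per_epoch_crop(768 * 1024 * 132, (256, 256, 8), 1, augment=False))
```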

In [None]:
class bcolors:
  WARNING = '\033[31m'

#@markdown ###Path to training data:
training_source = "" #@param {type:"string"}
training_target = "" #@param {type:"string"}

#@markdown ---

#@markdown ###Model name and path to model folder:
model_name = "" #@param {type:"string"}
model_path = "" #@param {type:"string"}

full_model_path = os.path.join(model_path, model_name)

#@markdown ---

#@markdown ###Training parameters
number_of_epochs =   100#@param {type:"number"}

#@markdown ###Default advanced parameters
use_default_advanced_parameters = True #@param {type:"boolean"}

#@markdown <font size = 3>If not, please change:

batch_size =  1#@param {type:"number"}
patch_size = (512,512,16) #@param {type:"number"} # in pixels
training_shape = patch_size + (1,)
image_pre_processing = 'resize to patch_size' #@param ["randomly crop to patch_size", "resize to patch_size"]

validation_split_in_percent = 20 #@param{type:"number"}
downscaling_in_xy = 1 #@param {type:"number"} # in pixels


if image_pre_processing == "randomly crop to patch_size":
    random_crop = True
else:
    random_crop = False

#@markdown ---

#@markdown ###Checkpointing

checkpointing_period = 1 #@param {type:"number"}

#@markdown  <font size = 3>If selected, only the best checkpoint is saved; otherwise a checkpoint is saved every `checkpointing_period` epochs:
save_best_only = True #@param {type:"boolean"}

#@markdown <font size = 3>Choose if training was interrupted:
resume_training = False #@param {type:"boolean"}

#@markdown <font size = 3>For transfer learning, do not select resume_training and specify a checkpoint_path below:
checkpoint_path = "" #@param {type:"string"}

if resume_training and checkpoint_path != "":
    print('If resume_training is True while checkpoint_path is specified, resume_training will be set to False!')
    resume_training = False

#@markdown ---

#@markdown ###Data Augmentation

apply_data_augmentation = True #@param {type:"boolean"}
 

# Retrieve last checkpoint
if resume_training:
    try:
      ckpt_dir_list = glob(full_model_path + '/ckpt/*')
      ckpt_dir_list.sort()
      last_ckpt_path = ckpt_dir_list[-1]
      print('Training will resume from checkpoint:', os.path.basename(last_ckpt_path))
    except IndexError:
      last_ckpt_path=None
      print('CheckpointError: No previous checkpoints were found, training from scratch.')
elif not resume_training and checkpoint_path != "":
    last_ckpt_path = checkpoint_path
    assert os.path.isfile(last_ckpt_path), 'checkpoint_path does not exist!'
else:
    last_ckpt_path=None


if use_default_advanced_parameters: 
    print("Default advanced parameters enabled")
    batch_size = 1
    training_shape = (256,256,8,1)
    validation_split_in_percent = 20
    downscaling_in_xy = 1
    random_crop = False

# Instantiate Unet3D 
model = Unet3D(shape=training_shape)

# Check whether a model with the same name already exists; if so, delete it (unless training is resumed)
if not resume_training and os.path.exists(full_model_path):
    print('!! WARNING: Folder already exists and will be overwritten !!') 
    shutil.rmtree(full_model_path)

if not os.path.exists(full_model_path):
    os.makedirs(full_model_path)

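# Show sample image (use the first file if directories were provided).
# training_source and training_target are left unchanged so that the full
# directories, not just the first file, are passed on to training.
if os.path.isdir(training_source):
    training_source_sample = sorted(glob(os.path.join(training_source, '*')))[0]
    training_target_sample = sorted(glob(os.path.join(training_target, '*')))[0]
else:
    training_source_sample = training_source
    training_target_sample = training_target

src_sample = tifffile.imread(training_source_sample)
src_sample = src_sample/255
# np.bool was removed from recent NumPy releases; the built-in bool is equivalent
tgt_sample = tifffile.imread(training_target_sample).astype(bool)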

src_down = transform.downscale_local_mean(src_sample[0], (downscaling_in_xy, downscaling_in_xy))
tgt_down = transform.downscale_local_mean(tgt_sample[0], (downscaling_in_xy, downscaling_in_xy))   

if random_crop:
    true_patch_size = None

    if src_down.shape[0] == training_shape[0]:
      x_rand = 0
    if src_down.shape[1] == training_shape[1]:
      y_rand = 0
    if src_down.shape[0] > training_shape[0]:
      x_rand = np.random.randint(src_down.shape[0] - training_shape[0])
    if src_down.shape[1] > training_shape[1]:
      y_rand = np.random.randint(src_down.shape[1] - training_shape[1])
    if src_down.shape[0] < training_shape[0] or src_down.shape[1] < training_shape[1]:
      raise ValueError('Patch shape larger than (downscaled) source shape')
else:
    true_patch_size = src_down.shape

def scroll_in_z(z):
    src_down = transform.downscale_local_mean(src_sample[z-1], (downscaling_in_xy,downscaling_in_xy))
    tgt_down = transform.downscale_local_mean(tgt_sample[z-1], (downscaling_in_xy,downscaling_in_xy))       
    if random_crop:
        src_slice = src_down[x_rand:training_shape[0]+x_rand, y_rand:training_shape[1]+y_rand]
        tgt_slice = tgt_down[x_rand:training_shape[0]+x_rand, y_rand:training_shape[1]+y_rand]
    else:
        
        src_slice = transform.resize(src_down, (training_shape[0], training_shape[1]), mode='constant', preserve_range=True)
        tgt_slice = transform.resize(tgt_down, (training_shape[0], training_shape[1]), mode='constant', preserve_range=True)

    f=plt.figure(figsize=(16,8))
    plt.subplot(1,2,1)
    plt.imshow(src_slice, cmap='gray')
    plt.title('Training source (z = ' + str(z) + ')', fontsize=15)
    plt.axis('off')

    plt.subplot(1,2,2)
    plt.imshow(tgt_slice, cmap='magma')
    plt.title('Training target (z = ' + str(z) + ')', fontsize=15)
    plt.axis('off')

print('This is what the training images will look like with the chosen settings')
interact(scroll_in_z, z=widgets.IntSlider(min=1, max=src_sample.shape[0], step=1, value=0));

# Save model parameters
params =  {'training_source': training_source,
           'training_target': training_target,
           'model_name': model_name,
           'model_path': model_path,
           'number_of_epochs': number_of_epochs,
           'batch_size': batch_size,
           'training_shape': training_shape,
           'downscaling': downscaling_in_xy,
           'true_patch_size': true_patch_size,
           'val_split': validation_split_in_percent/100,
           'random_crop': random_crop,
           'data_augmentation': apply_data_augmentation}

params_df = pd.DataFrame.from_dict(params, orient='index')
# Check if file is actually made
params_df.to_csv(os.path.join(full_model_path, 'params.csv'))

# **4. Train the network**
---

## **4.1. Train the network**
---


<font size = 4>**CRITICAL NOTE:** Google Colab has a time limit for processing (to prevent using GPU power for data mining). Training times must be less than 12 hours! If training takes longer than 12 hours, please decrease `number_of_epochs`.

In [None]:
#@markdown ##Show model summary
model.summary()

In [None]:
#@markdown ##Start Training

# Start Training
model.train(epochs=number_of_epochs,
            batch_size=batch_size,
            train_source=training_source,
            train_target=training_target,
            model_path=model_path,
            model_name=model_name,
            val_split=validation_split_in_percent/100,
            augment=apply_data_augmentation,
            ckpt_period=checkpointing_period,
            save_best_ckpt_only=save_best_only,
            ckpt_path=last_ckpt_path,
            random_crop=random_crop,
            downscaling=downscaling_in_xy)

print('Training successfully completed!')

## **4.3. Download your model(s) from Google Drive**


---
<font size = 4>Once training is complete, the trained model is automatically saved to your Google Drive, in the **`model_path`** folder that was specified in Section 3. Download the folder to avoid any unwanted surprises, since the data can be erased if you train another model using the same `model_path`.

# **5. Evaluate your model**
---

<font size = 4>In this section the newly trained model can be assessed for performance. This involves inspecting the loss function in Section 5.1. and employing more advanced metrics in Section 5.2.

<font size = 4>**We highly recommend performing quality control on all newly trained models.**



In [None]:
#@markdown ###Model to be evaluated:
#@markdown <font size = 3>If left blank, the latest model defined in Section 3 will be evaluated:

qc_model_name = "" #@param {type:"string"}
qc_model_path = "" #@param {type:"string"}

if len(qc_model_path) == 0 and len(qc_model_name) == 0:
    qc_model_name = model_name
    qc_model_path = model_path

full_qc_model_path = os.path.join(qc_model_path, qc_model_name)

if os.path.exists(full_qc_model_path):
    print(qc_model_name + ' will be evaluated')
else:
    W  = '\033[0m'  # white (normal)
    R  = '\033[31m' # red
    print(R+'!! WARNING: The chosen model does not exist !!'+W)
    print('Please make sure you provide a valid model path and model name before proceeding further.')


## **5.1. Inspecting loss function**
---

<font size = 4>**The training loss** is the error between prediction and target after each epoch calculated across the training data while the **validation loss** calculates the error on the (unseen) validation data. During training these values should decrease until converging at which point the model has been sufficiently trained. If the validation loss starts increasing while the training loss has plateaued, the model has overfit on the training data which reduces its ability to generalise. Aim to halt training before this point.

<font size = 4>**Note:** For a more in-depth explanation please refer to [this review](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6381354/) by Nichols et al.
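
<font size = 4>The point at which the validation loss stops improving can be located from the logged values. The following is a minimal sketch, assuming the per-epoch validation losses are available as a plain Python list; the function names and the example values are hypothetical and not part of the notebook.

```python
def best_epoch(val_losses):
    """Index of the epoch with the lowest validation loss (0-based)."""
    return min(range(len(val_losses)), key=val_losses.__getitem__)


def epochs_since_improvement(val_losses):
    """Number of epochs trained after the best validation loss was reached."""
    return len(val_losses) - 1 - best_epoch(val_losses)


# Hypothetical validation losses: decreasing at first, then rising again
val_losses = [0.90, 0.60, 0.45, 0.40, 0.42, 0.47, 0.55]

print('Best epoch:', best_epoch(val_losses))
print('Epochs without improvement:', epochs_since_improvement(val_losses))
```

<font size = 4>A growing number of epochs without improvement, while the training loss keeps falling or has plateaued, is the overfitting pattern described above.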


<font size = 4>The accuracy is another performance metric that is calculated after each epoch. We use the [Sørensen–Dice coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient) to score the prediction accuracy. 
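
<font size = 4>For two binary masks, the Sørensen–Dice coefficient is twice the number of overlapping foreground pixels divided by the total number of foreground pixels in both masks. The sketch below illustrates this on flattened masks in plain Python. It is not the metric implementation used during training, which may differ (for instance by operating on probabilities), and returning 1.0 for two empty masks is an assumption.

```python
def dice_coefficient(prediction, target):
    """Sørensen–Dice coefficient of two flattened binary masks."""
    pred = [bool(p) for p in prediction]
    tgt = [bool(t) for t in target]

    intersection = sum(p and t for p, t in zip(pred, tgt))
    total = sum(pred) + sum(tgt)

    if total == 0:
        # Assumption: two empty masks are treated as perfect agreement
        return 1.0
    return 2 * intersection / total


# One of the two predicted foreground pixels overlaps with the target
print(dice_coefficient([1, 1, 0, 0], [1, 0, 1, 0]))
```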



In [None]:
#@markdown ##Visualise loss and accuracy
lossDataFromCSV = []
vallossDataFromCSV = []
accuracyDataFromCSV = []
valaccuracyDataFromCSV = []

with open(full_qc_model_path + '/Quality Control/training_evaluation.csv', 'r') as csvfile:
    csvRead = csv.reader(csvfile, delimiter=',')
    next(csvRead)
    for row in csvRead:
        lossDataFromCSV.append(float(row[2]))
        vallossDataFromCSV.append(float(row[4]))
        accuracyDataFromCSV.append(float(row[1]))
        valaccuracyDataFromCSV.append(float(row[3]))

epochNumber = range(len(lossDataFromCSV))
plt.figure(figsize=(15,10))

plt.subplot(2,1,1)
plt.plot(epochNumber,lossDataFromCSV, label='Training loss')
plt.plot(epochNumber,vallossDataFromCSV, label='Validation loss')
plt.title('Training and validation loss', fontsize=14)
plt.ylabel('Loss', fontsize=12)
plt.xlabel('Epochs', fontsize=12)
plt.legend()

plt.subplot(2,1,2)
plt.plot(epochNumber,accuracyDataFromCSV, label='Training accuracy')
plt.plot(epochNumber,valaccuracyDataFromCSV, label='Validation accuracy')
plt.title('Training and validation accuracy', fontsize=14)
plt.ylabel('Dice', fontsize=12)
plt.xlabel('Epochs', fontsize=12)
plt.legend()
plt.savefig(full_qc_model_path + '/Quality Control/lossCurvePlots.png')
plt.show()



## **5.2. Error mapping and quality metrics estimation**
---
<font size = 4>This section provides a visual indication of the model performance by displaying the source, target and predicted volumes side by side, together with an overlay of the target and the prediction.

In [None]:
#@markdown ##Compare prediction and ground-truth on testing data

#@markdown <font size = 4>Provide an unseen annotated dataset to determine the performance of the model:

testing_source = "" #@param{type:"string"}
testing_target = "" #@param{type:"string"}

qc_dir = full_qc_model_path + '/Quality Control'
predict_dir = qc_dir + '/Prediction'
if os.path.exists(predict_dir):
    shutil.rmtree(predict_dir)

os.makedirs(predict_dir)

predict_path = predict_dir + '/' + os.path.splitext(os.path.basename(testing_source))[0] + '_prediction.tif'

def last_chars(x):
    return(x[-11:])

try:
    ckpt_dir_list = glob(full_qc_model_path + '/ckpt/*')
    ckpt_dir_list.sort(key=last_chars)
    last_ckpt_path = ckpt_dir_list[0]
    print('Predicting from checkpoint:', os.path.basename(last_ckpt_path))
except IndexError:
    raise CheckpointError('No previous checkpoints were found, please retrain model.')

# Load parameters
params = pd.read_csv(os.path.join(full_qc_model_path, 'params.csv'), names=['val'], header=0, index_col=0)   

model = Unet3D(shape=params.loc['training_shape', 'val'])

prediction = model.predict(testing_source, last_ckpt_path, downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])

tifffile.imwrite(predict_path, prediction.astype('float32'), imagej=True)

print('Predicted images!')

qc_metrics_path = full_qc_model_path + '/Quality Control/QC_metrics_' + qc_model_name + '.csv'

test_target = tifffile.imread(testing_target)
test_source = tifffile.imread(testing_source)
test_prediction = tifffile.imread(predict_path)

def scroll_in_z(z):

    plt.figure(figsize=(25,5))
    # Source
    plt.subplot(1,4,1)
    plt.axis('off')
    plt.imshow(test_source[z-1], cmap='gray')
    plt.title('Source (z = ' + str(z) + ')', fontsize=15)

    # Target (Ground-truth)
    plt.subplot(1,4,2)
    plt.axis('off')
    plt.imshow(test_target[z-1], cmap='magma')
    plt.title('Target (z = ' + str(z) + ')', fontsize=15)

    # Prediction
    plt.subplot(1,4,3)
    plt.axis('off')
    plt.imshow(test_prediction[z-1], cmap='magma')
    plt.title('Prediction (z = ' + str(z) + ')', fontsize=15)
    
    # Overlay
    plt.subplot(1,4,4)
    plt.axis('off')
    plt.imshow(test_target[z-1], cmap='Greens')
    plt.imshow(test_prediction[z-1], alpha=0.5, cmap='Purples')
    plt.title('Overlay (z = ' + str(z) + ')', fontsize=15)

interact(scroll_in_z, z=widgets.IntSlider(min=1, max=test_source.shape[0], step=1, value=0));

## **5.3. Determine best Intersection over Union and threshold**
---
<font size = 4>This section will provide both a visual and a quantitative indication of the model performance by comparing the overlay of the predicted and source volume, as well as computing the highest [**Intersection over Union**](https://en.wikipedia.org/wiki/Jaccard_index) (IoU) score. The IoU is also known as the Jaccard Index.  

<font size = 4>The best threshold is calculated using the IoU. Each threshold value from 0 to 255 is tested and the threshold with the highest score is deemed the best. The IoU is calculated for the entire volume in 3D.
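
<font size = 4>The threshold search described above can be illustrated on a handful of values. The sketch below uses plain Python lists instead of image volumes; the helper names and the example values are hypothetical, and returning an IoU of 0.0 when both masks are empty is an assumption.

```python
def rescale_to_255(values):
    """Linearly rescale values so that the minimum maps to 0 and the maximum to 255."""
    lo, hi = min(values), max(values)
    if hi == lo:
        return [0.0 for _ in values]
    return [(v - lo) * 255 / (hi - lo) for v in values]


def iou(mask, target):
    """Intersection over Union of two flattened binary masks."""
    intersection = sum(m and t for m, t in zip(mask, target))
    union = sum(m or t for m, t in zip(mask, target))
    # Assumption: an empty union is scored as 0.0 rather than dividing by zero
    return intersection / union if union else 0.0


def best_threshold(prediction, target):
    """Test every threshold from 0 to 255 and return (threshold, IoU) of the best one."""
    scaled = rescale_to_255(prediction)
    scores = [iou([v > threshold for v in scaled], target) for threshold in range(256)]
    best = max(range(256), key=scores.__getitem__)  # first threshold with the highest IoU
    return best, scores[best]


# Hypothetical prediction values and the matching binary target
prediction = [0.05, 0.2, 0.6, 0.9, 0.95, 0.1]
target = [False, False, True, True, True, False]

threshold, score = best_threshold(prediction, target)
print('Highest IoU is {:.4f} with a threshold of {}'.format(score, threshold))
```

<font size = 4>When several thresholds reach the same IoU, this sketch returns the lowest of them.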

In [None]:
#@markdown ##Calculate Intersection over Union and best threshold 
prediction = tifffile.imread(predict_path)
prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))

target = tifffile.imread(testing_target).astype(bool)

def iou_vs_threshold(prediction, target):
    threshold_list = []
    IoU_scores_list = []

    for threshold in range(0,256): 
        mask = prediction > threshold

        intersection = np.logical_and(target, mask)
        union = np.logical_or(target, mask)
        iou_score = np.sum(intersection) / np.sum(union)

        threshold_list.append(threshold)
        IoU_scores_list.append(iou_score)

    return threshold_list, IoU_scores_list

threshold_list, IoU_scores_list = iou_vs_threshold(prediction, target)
thresh_arr = np.array(list(zip(threshold_list, IoU_scores_list)))
# Thresholds run from 0 to 255, so the index of the highest IoU equals the threshold
best_thresh = int(np.argmax(IoU_scores_list))
best_iou = IoU_scores_list[best_thresh]

print('Highest IoU is {:.4f} with a threshold of {}'.format(best_iou, best_thresh))

def adjust_threshold(threshold, z):

    f=plt.figure(figsize=(25,5))
    plt.subplot(1,4,1)
    plt.imshow((prediction[z-1] > threshold).astype('uint8'), cmap='magma')
    plt.title('Prediction (Threshold = ' + str(threshold) + ')', fontsize=15)
    plt.axis('off')

    plt.subplot(1,4,2)
    plt.imshow(target[z-1], cmap='magma')
    plt.title('Target (z = ' + str(z) + ')', fontsize=15)
    plt.axis('off')

    plt.subplot(1,4,3)
    plt.axis('off')
    plt.imshow(test_source[z-1], cmap='gray')
    plt.imshow((prediction[z-1] > threshold).astype('uint8'), alpha=0.4, cmap='Reds')
    plt.title('Overlay (z = ' + str(z) + ')', fontsize=15)

    plt.subplot(1,4,4)
    plt.title('Threshold vs. IoU', fontsize=15)
    plt.plot(threshold_list, IoU_scores_list)
    plt.plot(threshold, IoU_scores_list[threshold], 'ro')     
    plt.ylabel('IoU score')
    plt.xlabel('Threshold')
    plt.show()

interact(adjust_threshold, 
         threshold=widgets.IntSlider(min=0, max=255, step=1, value=best_thresh),
         z=widgets.IntSlider(min=1, max=prediction.shape[0], step=1, value=0));

# **6. Using the trained model**

---

<font size = 4>Once sufficient performance of the trained model has been established using Section 5, the network can be used to segment unseen volumetric data.

## **6.1. Generate predictions from unseen dataset**
---

<font size = 4>The most recently trained model can now be used to predict segmentation masks on unseen images. To use an older model instead, specify it in `model_path`; if left blank, the most recently trained model is used. Predicted output images are saved in `output_path` as ImageJ-compatible TIFF files.

## **Prediction parameters**

* <font size = 4>**`source_path`** specifies the location of the source image volume.

* <font size = 4>**`output_directory`** specifies the directory where the output predictions are stored.

* <font size = 4>**`threshold`** can be calculated in Section 5 and is used to generate binary masks from the predictions.

* <font size = 4>**`big_tiff`** should be chosen if the expected prediction exceeds 4GB. The predictions will be saved using the BigTIFF format. Beware that this might substantially reduce the prediction speed. *Default: False* 

* <font size = 4>**`prediction_depth`** is only relevant if the prediction is saved as a BigTIFF. To avoid exhausting the available memory, the prediction is not performed in one go. Instead, it is performed iteratively on subsets of the entire volume with shape `(source.shape[0], source.shape[1], prediction_depth)`. *Default: 32*

* <font size = 4>**`model_path`** specifies the path to a model other than the most recently trained.
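
<font size = 4>The interplay of `big_tiff` and `prediction_depth` can be sketched with a few lines of arithmetic. The helper names below are hypothetical. The size estimate assumes a `float32` prediction (4 bytes per voxel) with the same shape as the source, whereas the prediction cell itself checks the size of the loaded source volume.

```python
import math


def z_ranges(depth, prediction_depth):
    """Split a volume with `depth` slices into consecutive (start, stop) z-ranges."""
    return [(start, min(start + prediction_depth, depth))
            for start in range(0, depth, prediction_depth)]


def needs_bigtiff(shape, bytes_per_voxel=4, limit=4e9):
    """Estimate whether a volume of the given shape reaches the 4 GB limit."""
    return math.prod(shape) * bytes_per_voxel >= limit


# Hypothetical volume with 70 slices, predicted 32 slices at a time
print(z_ranges(70, 32))

# Hypothetical volume shapes given as (z, y, x)
print(needs_bigtiff((165, 768, 1024)))
print(needs_bigtiff((2000, 1024, 1024)))
```

<font size = 4>A smaller `prediction_depth` yields more, thinner z-ranges and therefore a lower peak memory use per prediction step.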

In [None]:
#@markdown ## Download example volume

#@markdown <font size = 4> This can take up to an hour

import requests  
import os
from tqdm.notebook import tqdm 


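def download_from_url(url, save_as):
    # Create the target directory if needed so that open() does not fail
    save_dir = os.path.dirname(save_as)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    r = requests.get(url, stream=True)
    r.raise_for_status()  # stop on HTTP errors instead of saving an error page

    with open(save_as, 'wb') as file:
        # total is hard-coded for the example volume (number of 1024-byte blocks)
        for block in tqdm(r.iter_content(chunk_size=1024), desc='Downloading ' + os.path.basename(save_as), total=3275073, ncols=1000):
            if block:
                file.write(block)

download_from_url('https://documents.epfl.ch/groups/c/cv/cvlab-unit/www/data/%20ElectronMicroscopy_Hippocampus/volumedata.tif', 'example_dataset/volumedata.tif')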

In [None]:
#@markdown ### Provide the path to your dataset and to the folder where the predictions are saved, then run the cell to predict outputs from your unseen images.

source_path = "" #@param {type:"string"}
output_directory = "" #@param {type:"string"}

if not os.path.exists(output_directory):
    os.makedirs(output_directory)

output_path = os.path.join(output_directory, os.path.splitext(os.path.basename(source_path))[0] + '_predicted.tif')
#@markdown ###Prediction parameters:

save_probability_map = False #@param {type:"boolean"}

#@markdown <font size = 3>Determine best threshold in Section 5.3.

use_calculated_threshold = True #@param {type:"boolean"}
threshold =  200#@param {type:"number"}

# Tifffile library issues means that images cannot be appended to 
#@markdown <font size = 3>Choose if prediction file exceeds 4GB or if input file is very large (above 2GB). Image volume saved as BigTIFF.
big_tiff = False #@param {type:"boolean"}

#@markdown <font size = 3>Reduce `prediction_depth` if runtime runs out of memory during prediction. Only relevant if prediction saved as BigTIFF

prediction_depth = 32 #@param {type:"number"}

#@markdown ###Model to be evaluated
#@markdown <font size = 3>If left blank, the latest model defined in Section 3 will be evaluated

full_model_path_ = "" #@param {type:"string"}

if len(full_model_path_) == 0:
    full_model_path_ = os.path.join(model_path, model_name) 



# Load parameters
params = pd.read_csv(os.path.join(full_model_path_, 'params.csv'), names=['val'], header=0, index_col=0)   
model = Unet3D(shape=params.loc['training_shape', 'val'])

if use_calculated_threshold:
    threshold = best_thresh

def last_chars(x):
    return(x[-11:])

try:
    ckpt_dir_list = glob(full_model_path_ + '/ckpt/*')
    ckpt_dir_list.sort(key=last_chars)
    last_ckpt_path = ckpt_dir_list[0]
    print('Predicting from checkpoint:', os.path.basename(last_ckpt_path))
except IndexError:
    raise CheckpointError('No previous checkpoints were found, please retrain model.')

src = tifffile.imread(source_path)

if src.nbytes >= 4e9:
    big_tiff = True
    print('The source file exceeds 4GB in memory, prediction will be saved as BigTIFF!')

if not big_tiff:
    prediction = model.predict(src, last_ckpt_path, downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])
    prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))
    prediction = (prediction > threshold).astype('float32')

    tifffile.imwrite(output_path, prediction, imagej=True)

else:
    with tifffile.TiffWriter(output_path, bigtiff=True) as tif:
        for i in tqdm(range(0, src.shape[0], prediction_depth)):
            prediction = model.predict(src, last_ckpt_path, z_range=(i,i+prediction_depth), downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])
            prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))
            prediction = (prediction > threshold).astype('float32')
            
            for j in range(prediction.shape[0]):
                tif.save(prediction[j])

if save_probability_map:
    prob_map_path = os.path.splitext(output_path)[0] + '_prob_map.tif'
    
    if not big_tiff:
        prediction = model.predict(src, last_ckpt_path, downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])
        prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))
        tifffile.imwrite(prob_map_path, prediction.astype('float32'), imagej=True)

    else:
        with tifffile.TiffWriter(prob_map_path, bigtiff=True) as tif:
            for i in tqdm(range(0, src.shape[0], prediction_depth)):
                prediction = model.predict(src, last_ckpt_path, z_range=(i,i+prediction_depth), downscaling=params.loc['downscaling', 'val'], true_patch_size=params.loc['true_patch_size', 'val'])
                prediction = np.interp(prediction, (prediction.min(), prediction.max()), (0, 255))
                
                for j in range(prediction.shape[0]):
                    tif.save(prediction[j])

print('Predictions saved as', output_path)

src_volume = tifffile.imread(source_path)
pred_volume = tifffile.imread(output_path)

def scroll_in_z(z):
  
    f=plt.figure(figsize=(25,5))
    plt.subplot(1,2,1)
    plt.imshow(src_volume[z-1], cmap='gray')
    plt.title('Source (z = ' + str(z) + ')', fontsize=15)
    plt.axis('off')

    plt.subplot(1,2,2)
    plt.imshow(pred_volume[z-1], cmap='magma')
    plt.title('Prediction (z = ' + str(z) + ')', fontsize=15)
    plt.axis('off')

interact(scroll_in_z, z=widgets.IntSlider(min=1, max=src_volume.shape[0], step=1, value=0));


## **6.2. Download your predictions**
---

<font size = 4>**Store your data** and ALL its results elsewhere by downloading it from Google Drive and after that clean the original folder tree (datasets, results, trained model etc.) if you plan to train or use new networks. Please note that the notebook will otherwise **OVERWRITE** all files which have the same name.
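
<font size = 4>Outside of Google Colab, a model directory can also be archived with the Python standard library. The following is a minimal sketch; the function `zip_directory` and the placeholder directory `my_model` are hypothetical, and the demonstration runs in a temporary directory so that no existing files are touched.

```python
import os
import shutil
import tempfile
import zipfile


def zip_directory(directory, archive_base=None):
    """Zip `directory` (including its top-level folder) and return the archive path."""
    directory = os.path.abspath(directory)
    if archive_base is None:
        # Place the archive next to the directory, e.g. my_model -> my_model.zip
        archive_base = directory
    return shutil.make_archive(archive_base, 'zip',
                               root_dir=os.path.dirname(directory),
                               base_dir=os.path.basename(directory))


# Demonstration with a placeholder model directory
with tempfile.TemporaryDirectory() as root:
    model_dir = os.path.join(root, 'my_model')
    os.makedirs(model_dir)
    with open(os.path.join(model_dir, 'params.csv'), 'w') as f:
        f.write('placeholder\n')

    zip_path = zip_directory(model_dir)
    with zipfile.ZipFile(zip_path) as archive:
        archive_names = archive.namelist()

print(archive_names)
```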

In [None]:
#@markdown ##Download model directory
#@markdown 1.  <font size = 4>Specify the model_path
#@markdown 2.  <font size = 4>Run this cell to zip the model directory
#@markdown 3.  <font size = 4>Download the zipped file from the *Files* tab on the left

from google.colab import files

model_path_download = "" #@param {type:"string"}

if len(model_path_download) == 0:
    model_path_download = model_path

model_path_download = os.path.basename(model_path_download)

print('Zipping', model_path_download)

zip_model_path = model_path_download + '.zip'

!zip -r $zip_model_path $model_path_download

print('Successfully saved zipped model directory as', zip_model_path)


# **Thank you for using 3D U-Net!**