# Example 1B - Training of virtual staining of brightfield images (60x)

Example code to train a neural network to virtually stain brightfield images captured with the 60x magnification objective, producing the corresponding fluorescence images of nuclei, lipids, and cytoplasm.

version 1.0 <br />
15 November 2020 <br />
Benjamin Midtvedt, Jesús Pineda Castro, Saga Helgadottir, Daniel Midtvedt & Giovanni Volpe <br />
Soft Matter Lab @ GU <br />
http://www.softmatterlab.org

## 0. Imports
 
Import all necessary packages. These include standard Python packages as well as the core of DeepTrack 2.0 (`deeptrack`) and some specialized classes for this virtual staining (`apido`).

In [8]:
import os
import glob
import random
import itertools
from timeit import default_timer as timer

import numpy as np
from PIL import Image

# DeepTrack 2.0 code
import deeptrack as dt
import apido

## 1. Define input and output

Set constants to determine the input and output images

### 1.1 Neural-network model parameters

Parameters of the neural network model. These are:

* `GENERATOR_BREADTH`: determines the width of the input image as `GENERATOR_BREADTH * 32` (e.g., `GENERATOR_BREADTH = 16` corresponds to an input image size of `512`)

* `GENERATOR_DEPTH`: Depth of the generator U-Net

* `DISCRIMINATOR_DEPTH`: Depth of the discriminator convolutional encoder

* `MAE_LOSS_WEIGHT`: the weighting of the MAE loss vs. the adversarial loss

In [None]:
GENERATOR_BREADTH = 16
GENERATOR_DEPTH = 5
DISCRIMINATOR_DEPTH = 5
MAE_LOSS_WEIGHT = 0.001
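
As a quick check of the size relation stated above, the sketch below computes the input width implied by a given breadth. The helper `input_size` is hypothetical and not part of `apido`; it only restates the `GENERATOR_BREADTH * 32` rule from the text.

```python
# Hypothetical helper: restates the "breadth * 32" rule from the text.
def input_size(generator_breadth, factor=32):
    return generator_breadth * factor


GENERATOR_BREADTH = 16
print(input_size(GENERATOR_BREADTH))  # 512
```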

### 1.2 User-defined constants for loading data and saving model

Constants defined by the user:

* `DATASET_PATH`: Input path (not including the magnification folder)

* `OUTPUT_PATH`: Output path (not including the magnification folder)

* `WELLS`: Names of the wells to use for training and validation

* `SITES`: "all" or list of integers

In [2]:
DATASET_PATH = "./train_data/" 
OUTPUT_PATH = "./models/"
WELLS = ["C02"]
SITES = [1, 2, 3, 4]

### 1.3 Inferred constants

Constants inferred from the user input

In [3]:
MAGNIFICATION = "60x"
file_name_struct = "AssayPlate_Greiner_#655090_{0}_T0001F{1}L01A0{2}Z0{3}C0{2}.tif"

PATH_TO_OUTPUT = os.path.normpath(OUTPUT_PATH)
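
To make the file-name template concrete, the sketch below fills it in for one well, site, and z-slice. The placeholder meanings in the comment are inferred from how the template is used later in this notebook (the third argument is named `action` there) and are not an official description of the naming scheme.

```python
file_name_struct = "AssayPlate_Greiner_#655090_{0}_T0001F{1}L01A0{2}Z0{3}C0{2}.tif"

# Inferred placeholder meanings: {0} = well, {1} = zero-padded site,
# {2} = action index (inserted twice), {3} = z-slice index.
example = file_name_struct.format("C02", "001", 4, 1)
print(example)
# AssayPlate_Greiner_#655090_C02_T0001F001L01A04Z01C04.tif
```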

Infer full path to dataset

In [4]:
_glob_struct = os.path.join(DATASET_PATH, MAGNIFICATION + "*/")
_glob_results = glob.glob(_glob_struct)

if len(_glob_results) == 0:
    raise ValueError("No path found matching glob {0}".format(_glob_struct))
elif len(_glob_results) > 1:
    from warnings import warn
    warn("Multiple paths found! Using {0}".format(_glob_results[0]))

PATH_TO_MAGNIFICATION = os.path.normpath(_glob_results[0])

In [5]:
print("Loading images from: \t", PATH_TO_MAGNIFICATION)
print("Saving results to: \t", PATH_TO_OUTPUT)

Loading images from: 	 local_data\60x images
Saving results to: 	 models


## 2. Load train data

We define a data pipeline for loading images from storage. The pipeline is built with DeepTrack 2.0 and proceeds as follows:

1. Load each z-slice of a well-site combination and concatenate them.
2. Pad the volume such that the first two dimensions are multiples of 32 (required by the model).
3. Correct for misalignment of the fluorescence channel and the brightfield channel (by a pre-calculated parametrization of the offset as a function of magnification and the site).
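
Step 2 above can be illustrated with plain arithmetic. The sketch below only computes the padded shape; the helper names and the example image size are hypothetical, and the actual padding in this notebook is done by DeepTrack.

```python
def padded_length(length, multiple=32):
    # Ceiling division, then scale back up to the next multiple.
    return -(-length // multiple) * multiple


def padded_shape(shape, multiples=(32, 32, None)):
    # `None` leaves an axis (here, the z/channel axis) untouched.
    return tuple(
        size if multiple is None else padded_length(size, multiple)
        for size, multiple in zip(shape, multiples)
    )


# Hypothetical image size, chosen only for illustration.
print(padded_shape((2156, 2556, 7)))  # (2176, 2560, 7)
```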

### 2.1 Find all wells and sites

We create a list of all well-site combinations. `itertools.product` produces an iterator over every combination of its inputs; in this case, each site in each well.

In [6]:
wells_and_sites = list(
    itertools.product(
        WELLS,
        SITES if isinstance(SITES, list) else range(1, 13) 
    )
)
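
For illustration, the same expression with two wells and `SITES = "all"` yields every well paired with sites 1 to 12. The second well name is a made-up value used only for this example.

```python
import itertools

WELLS = ["C02", "C03"]  # "C03" is a hypothetical second well
SITES = "all"

wells_and_sites = list(
    itertools.product(
        WELLS,
        SITES if isinstance(SITES, list) else range(1, 13),
    )
)

print(len(wells_and_sites))  # 24
print(wells_and_sites[:3])   # [('C02', 1), ('C02', 2), ('C02', 3)]
```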

In [9]:
random.seed(1)
random.shuffle(wells_and_sites)

split = int(len(wells_and_sites) * 0.85)

training_set = wells_and_sites[:split]
validation_set = wells_and_sites[split:]

print("Training on {0} images".format(len(training_set)))
print("Validating on {0} images".format(len(validation_set)))

Training on 3 images
Validating on 1 images


### 2.2 The root feature

We use DeepTrack 2.0 to define the data loader pipeline. The pipeline is a sequence of `features`, which perform computations, controlled by `properties`, which are defined when creating the features. (Note that we can add any property with any name and value to a feature; if a property is not used by the feature, we refer to it as a dummy property.)

The feature `root` is a `DummyFeature`, which is just a container of dummy properties and does not perform any computations.
It takes the following arguments:

* `well_site_tuple` is a dummy property that cycles through the well-site combinations of the training set (or of the validation set when `validation` is true)
* `well` is a dummy property that extracts the well from the `well_site_tuple`
* `site` is a dummy property that extracts the site from the `well_site_tuple`

Note that `well` and `site` are functions that take `well_site_tuple` as argument. These are dependent properties, and DeepTrack 2.0 will automatically ensure that they receive the correct input.

In [10]:
training_iterator = itertools.cycle(training_set)
validation_iterator = itertools.cycle(validation_set)

def get_next_well_and_site(validation=False):
    if validation:
        return next(validation_iterator)
    else:
        return next(training_iterator)

# Accepts a tuple of form (well, site), and returns the well
def get_well_from_tuple(well_site_tuple):
    return well_site_tuple[0]

# Accepts a tuple of form (well, site), and returns the site as
# a string formatted to be of length 3 (zero-padded, e.g. "001").
def get_site_from_tuple(well_site_tuple):
    site_string = "00" + str(well_site_tuple[1])
    return site_string[-3:]



root = dt.DummyFeature(
    well_site_tuple=get_next_well_and_site,           # On each update, root will grab the next value from this iterator
    well=get_well_from_tuple,                         # Grabs the well from the well_site_tuple
    site=get_site_from_tuple,                         # Grabs and formats the site from the well_site_tuple
)
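
The idea of dependent properties can be pictured with a toy resolver that matches function argument names against already-computed values. This is a simplified sketch, not DeepTrack's implementation; `resolve_properties` is a hypothetical name.

```python
import inspect
import itertools


def resolve_properties(properties, **external):
    """Evaluate properties in order, feeding each function the
    already-resolved values whose names match its arguments."""
    resolved = dict(external)
    for name, rule in properties.items():
        if callable(rule):
            wanted = inspect.signature(rule).parameters
            kwargs = {key: resolved[key] for key in wanted if key in resolved}
            resolved[name] = rule(**kwargs)
        else:
            resolved[name] = rule
    return resolved


sites = itertools.cycle([("C02", 1), ("C02", 2)])
properties = {
    "well_site_tuple": lambda: next(sites),
    "well": lambda well_site_tuple: well_site_tuple[0],
    "site": lambda well_site_tuple: str(well_site_tuple[1]).zfill(3),
}

first = resolve_properties(properties)
second = resolve_properties(properties)
print(first["well"], first["site"])    # C02 001
print(second["well"], second["site"])  # C02 002
```

Each call draws the next tuple once, and `well` and `site` are derived from that same tuple, so they always stay consistent with each other.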

### 2.3 The brightfield image loader

We use `deeptrack.LoadImage` to load and concatenate a brightfield stack. It takes the following arguments:

* `**root.properties` means that we take the properties of `root` (most importantly `well` and `site`). The other properties of `LoadImage` will now depend on these.
* `file_names` is a dummy property, which takes the current well and site as input, and creates a list of file names that we want to load.
* `path` is a property used by `LoadImage` to determine which files to load. We calculate it by taking `file_names` as input and returning a list of paths using `os.path.join`.

Since `path` is a list, `LoadImage` stacks the images along the last dimension, creating a volume of shape (width, height, 7).

In [11]:
brightfield_loader = dt.LoadImage(
    **root.properties,
    file_names=lambda well, site: [file_name_struct.format(well, site, 4, z) for z in range(1, 8)],
    path=lambda file_names: [os.path.join(PATH_TO_MAGNIFICATION, file_name) for file_name in file_names],
)

### 2.4 The fluorescence image loader

We use `deeptrack.LoadImage` to load and concatenate a fluorescence stack. It takes the following arguments:

* `**root.properties` means that we take the properties of `root` (most importantly `well` and `site`). The other properties of `LoadImage` will now depend on these.
* `file_names` is a dummy property, which takes the current well and site as input, and creates a list of file names that we want to load.
* `path` is a property used by `LoadImage` to determine which files to load. We calculate it by taking `file_names` as input and returning a list of paths using `os.path.join`.

Since `path` is a list, `LoadImage` stacks the images along the last dimension, creating a volume of shape (width, height, 3).

In [12]:
fluorescence_loader = dt.LoadImage(
    **root.properties,
    file_names=lambda well, site: [file_name_struct.format(well, site, action, 1) for action in range(1, 4)],
    path=lambda file_names: [os.path.join(PATH_TO_MAGNIFICATION, file_name) for file_name in file_names],
)

### 2.5 Offset adjustment

Offset adjustments using affine transformations. The offset is parametrized as a function of the magnification and the site as described in the report.

The properties are set as follows:
* `translate` sets how much we translate the image in pixels. It is a tuple representing the (x, y) shift. We calculate it as a function of the angular position of the site within the well, with site 1 at angle 0.
* `angle` is a dummy property that calculates the angle of the site in radians.

In [13]:
# Coefficients of the regression
Ax = +2.3054
Bx = -0.0315
Ay = -0.1352
By = -2.3049
x =  -0.8363
y =  +0.8081
scale= 0.99975

correct_offset = dt.Affine(
    translate=lambda angle: (
        (np.cos(angle) * Bx + np.sin(angle) * Ax + x) * -1, # Offset in x
        (np.cos(angle) * By + np.sin(angle) * Ay + y) * -1, # Offset in y
    ),
    angle = lambda site: (int(site) - 1) * np.pi / 6,
    **root.properties,
)
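
To see what the regression produces, the sketch below evaluates the same translation formula for two sites using only the standard library. The helper `site_offset` is a hypothetical name; the coefficients are copied from the cell above.

```python
import math

# Regression coefficients copied from the cell above.
Ax, Bx = +2.3054, -0.0315
Ay, By = -0.1352, -2.3049
x, y = -0.8363, +0.8081


def site_offset(site):
    # Site 1 sits at angle 0; each subsequent site advances by 30 degrees.
    angle = (int(site) - 1) * math.pi / 6
    dx = -(math.cos(angle) * Bx + math.sin(angle) * Ax + x)
    dy = -(math.cos(angle) * By + math.sin(angle) * Ay + y)
    return dx, dy


for site in ("001", "004"):
    dx, dy = site_offset(site)
    print(site, round(dx, 4), round(dy, 4))
# 001 0.8678 1.4968
# 004 -1.4691 -0.6729
```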

### 2.6 Define augmentations

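We use three kinds of augmentations: mirroring (`deeptrack.FlipLR`), affine transformations (`deeptrack.Affine`), and elastic distortions (`deeptrack.ElasticTransformation`). A final crop (`deeptrack.Crop`) extracts the central 512 x 512 region of the augmented image.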

In [14]:
flip = dt.FlipLR()

affine = dt.Affine(
    rotate=lambda: np.random.rand() * 2 * np.pi,
    scale=lambda: np.random.rand() * 0.1 + 0.95,
    shear=lambda: np.random.rand() * 0.05 - 0.025
)

distortion = dt.ElasticTransformation(
    alpha=lambda: np.random.rand() * 80,
    sigma=lambda: 7
)

# Top-left corner of the central 512 x 512 region (cast to int for indexing)
corner = int(512 * (np.sqrt(2) - 1) / 2)
cropping  = dt.Crop(
    crop=(512, 512, None),
    corner=(corner, corner, 0)
)
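
The `np.sqrt(2)` factors follow from geometry: a 512-pixel square can be rotated by any angle and still lie fully inside the source patch only if that patch has a side of at least 512 times the square root of 2 (the diagonal of the square). The pipeline in the next section therefore takes a larger crop first, and the final crop here picks out its center. A short check of the numbers, using only the standard library:

```python
import math

final_size = 512
padded_size = int(final_size * math.sqrt(2))   # side of the larger pre-crop
corner = final_size * (math.sqrt(2) - 1) / 2   # offset used for the final crop
centered = (padded_size - final_size) // 2     # integer offset of a centered crop

print(padded_size, round(corner, 2), centered)  # 724 106.04 106
```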

### 2.7 Create the pipeline 

We use the `+` operator to chain features and define the execution order. In DeepTrack 2.0, the output of the feature on the left is passed as the input to the feature on the right. The pipeline is built in the following steps:

1. `corrected_brightfield` is generated by applying `correct_offset` to the output of `brightfield_loader`.
2. `data_pair` combines the input images with the target images.
3. `cropped_data` takes a random crop of side `int(512 * np.sqrt(2))` from `data_pair`.
4. `augmented_data` applies the augmentations and the final 512 x 512 crop to `cropped_data`.
5. `validation_data` pads `data_pair` to multiples of 32, without augmentation.
6. `dataset` selects `validation_data` or `augmented_data` depending on the `validation` argument.

In [15]:
corrected_brightfield = brightfield_loader + correct_offset

data_pair = dt.Combine([corrected_brightfield, fluorescence_loader])

padded_crop_size = int(512 * np.sqrt(2))

cropped_data = dt.Crop(
    data_pair,
    crop=(padded_crop_size, padded_crop_size, None),
    updates_per_reload=16,
    corner="random",
)

augmented_data = cropped_data + flip + affine + distortion + cropping

validation_data = data_pair + dt.PadToMultiplesOf(multiple=(32, 32, None))

dataset = dt.ConditionalSetFeature(
    on_true=validation_data,
    on_false=augmented_data,
    condition="validation"
)
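
The left-to-right semantics of `+` can be pictured with a toy class. This is a simplified sketch and not how DeepTrack features are implemented; `Step` and the three example steps are hypothetical.

```python
class Step:
    """Toy stand-in for a feature: wraps a function and supports `+` chaining."""

    def __init__(self, function):
        self.function = function

    def __add__(self, other):
        # The left step runs first; its output becomes the right step's input.
        return Step(lambda value: other.function(self.function(value)))

    def resolve(self, value):
        return self.function(value)


load = Step(lambda name: [1, 2, 3, 4])  # pretend to load data from a file
flip = Step(lambda data: data[::-1])    # mirror the data
crop = Step(lambda data: data[:2])      # keep the first two entries

pipeline = load + flip + crop
print(pipeline.resolve("image.tif"))  # [4, 3]
```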

## 3. Define generator

We use generators to interface DeepTrack 2.0 features with Keras training routines. DeepTrack 2.0 provides specialized generators that speed up training. Here, we use `deeptrack.generators.ContinuousGenerator`, which continuously generates augmented training images and makes them available for training the neural network model.

In [21]:
generator = dt.generators.ContinuousGenerator(
    dataset,
    batch_function=lambda image: image[0],
    label_function=lambda image: image[1],
    batch_size=8,
    min_data_size=500,
    max_data_size=2000
)
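
The roles of `min_data_size`, `max_data_size`, `batch_function`, and `label_function` can be pictured with a toy sample pool. This is a single-threaded sketch under stated assumptions, not DeepTrack's generator: the class `SamplePool` is hypothetical, and discarding the oldest sample when the pool is full is an assumed policy.

```python
import random


class SamplePool:
    """Toy, single-threaded stand-in for a continuously refreshed sample pool."""

    def __init__(self, min_data_size, max_data_size):
        self.min_data_size = min_data_size
        self.max_data_size = max_data_size
        self.samples = []

    def add(self, sample):
        # Assumed policy: once the pool is full, drop the oldest sample.
        if len(self.samples) >= self.max_data_size:
            self.samples.pop(0)
        self.samples.append(sample)

    def ready(self):
        # Training would only start once enough samples are available.
        return len(self.samples) >= self.min_data_size

    def batch(self, batch_size, batch_function, label_function):
        chosen = random.sample(self.samples, batch_size)
        inputs = [batch_function(sample) for sample in chosen]
        labels = [label_function(sample) for sample in chosen]
        return inputs, labels


pool = SamplePool(min_data_size=4, max_data_size=6)
for i in range(8):
    pool.add((f"brightfield_{i}", f"fluorescence_{i}"))

inputs, labels = pool.batch(
    2,
    batch_function=lambda sample: sample[0],
    label_function=lambda sample: sample[1],
)
```

Each stored sample is an (input, target) pair, so `batch_function` and `label_function` simply pick the first and second element, as in the cell above.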

## 4. Define model

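Here, we use a GAN with a U-Net generator and a convolutional encoder discriminator. The discriminator is trained with an MSE loss, while the generator is trained with a weighted sum of the adversarial MSE loss and an MAE loss against the target images. The details are described in the report.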

In [None]:
GAN_generator = apido.generator(GENERATOR_BREADTH, GENERATOR_DEPTH)
GAN_discriminator = apido.discriminator(DISCRIMINATOR_DEPTH)

GAN = dt.models.cgan(
    generator=GAN_generator,
    discriminator=GAN_discriminator,
    discriminator_loss="mse",
    assemble_loss=["mse", "mae"],
    assemble_loss_weights=[1 - MAE_LOSS_WEIGHT, MAE_LOSS_WEIGHT]
)

GAN.compile(metrics=apido.metrics("60x"))
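
The weighting of the two generator losses can be written out explicitly. The sketch below is plain Python on small lists; how `dt.models.cgan` combines the terms internally is assumed from the argument names `assemble_loss` and `assemble_loss_weights`, and `assembled_loss` is a hypothetical helper.

```python
def mean_absolute_error(target, prediction):
    return sum(abs(t - p) for t, p in zip(target, prediction)) / len(target)


def mean_squared_error(target, prediction):
    return sum((t - p) ** 2 for t, p in zip(target, prediction)) / len(target)


MAE_LOSS_WEIGHT = 0.001


def assembled_loss(discriminator_output, real_label, generated, target):
    # Adversarial term: how far the discriminator is from calling the output real.
    adversarial = mean_squared_error(real_label, discriminator_output)
    # Reconstruction term: pixel-wise distance to the target image.
    reconstruction = mean_absolute_error(target, generated)
    return (1 - MAE_LOSS_WEIGHT) * adversarial + MAE_LOSS_WEIGHT * reconstruction


loss = assembled_loss(
    discriminator_output=[0.5, 0.5],
    real_label=[1.0, 1.0],
    generated=[0.2, 0.4],
    target=[0.0, 1.0],
)
print(round(loss, 5))  # 0.25015
```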

## 5. Train model

Finally, we train the model for 500 epochs, with 32 steps per epoch.

In [22]:
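# Resolve each validation image once, then regroup the (input, target) pairs
# into one array of inputs and one array of targets.
validation_pairs = [
    dataset.update(validation=True).resolve()
    for _ in range(len(validation_set))
]
validation_data = tuple(np.array(images) for images in zip(*validation_pairs))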

with generator:
    h = GAN.fit(
        generator, 
        epochs=500,
        steps_per_epoch=32,
        validation_data=validation_data,
        validation_batch_size=4
    )

Generating 0 / 500 samples before starting training

## 6. Save model
