# Train 60x Virtual Stainer

Notebook by Group 1: Team-Soft-Matter-Lab-GU

In [1]:
# Globals
import os
import glob
import random
import itertools
from timeit import default_timer as timer

# Locals
import apido
import deeptrack as dt

# Packages
import numpy as np
from PIL import Image

## 1. Constants

Sets of constants used by the notebook

### 1.1 User constants

Constants set by the user

* `DATASET_PATH`: Input path (not including the magnification folder)

* `OUTPUT_PATH`: Output path (not including the magnification folder)

* `WELLS`: Names of the wells to train on

* `SITES`: "all" or list of integers

In [2]:
GENERATOR_BREADTH = 16
GENERATOR_DEPTH = 5
DISCRIMINATOR_DEPTH = 5
MAE_LOSS_WEIGHT = 0.001

DATASET_PATH = "./local_data/" 
OUTPUT_PATH = "./models/"
WELLS = ["C02"]
SITES = [1, 2, 3, 4]

### 1.2 Inferred constants

Constants inferred from the user input

In [4]:
MAGNIFICATION = "60x"
file_name_struct = "AssayPlate_Greiner_#655090_{0}_T0001F{1}L01A0{2}Z0{3}C0{2}.tif"

PATH_TO_OUTPUT = os.path.normpath(OUTPUT_PATH)

Infer full path to dataset

In [None]:
_glob_struct = os.path.join(DATASET_PATH, MAGNIFICATION + "*/")
_glob_results = glob.glob(_glob_struct)

if len(_glob_results) == 0:
    raise ValueError("No path found matching glob {0}".format(_glob_struct))
elif len(_glob_results) > 1:
    from warnings import warn
    warn("Multiple paths found! Using {0}".format(_glob_results[0]))

PATH_TO_MAGNIFICATION = os.path.normpath(_glob_results[0])

In [None]:
print("Loading images from: \t", PATH_TO_MAGNIFICATION)
print("Saving results to: \t", PATH_TO_OUTPUT)

## 2. Load model

We load the virtual stainer from the local path.
This is expected to warn about overwriting `groups`.

## 3. Data loader

We define a data pipeline for loading images from storage. The pipeline uses DeepTrack 2.0 and performs the following steps:

1. Load each z-slice of a well-site combination and concatenate them.
2. Pad the volume such that the first two dimensions are multiples of 32 (required by the model).
3. Correct for misalignment between the fluorescence channel and the brightfield channel, using a pre-calculated parametrization of the offset as a function of the magnification and the site.
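
The padding step can be illustrated without DeepTrack. The following is a minimal NumPy sketch, not the library's implementation: the helper name `pad_to_multiple`, the choice to pad with zeros at the bottom and right edges, and the example image size are all assumptions made for illustration.

```python
import numpy as np

def pad_to_multiple(volume, multiple=32):
    """Zero-pad the first two axes up to the next multiple of `multiple`."""
    height, width = volume.shape[:2]
    # (-n) % m is the distance from n up to the next multiple of m (0 if already one)
    pad_h = (-height) % multiple
    pad_w = (-width) % multiple
    # Pad only the two spatial axes; leave any remaining axes untouched
    pad_width = [(0, pad_h), (0, pad_w)] + [(0, 0)] * (volume.ndim - 2)
    return np.pad(volume, pad_width, mode="constant")

volume = np.ones((100, 70, 7))  # hypothetical (width, height, z-slices) stack
padded = pad_to_multiple(volume)
print(padded.shape)  # (128, 96, 7)
```

A stack whose first two dimensions are already multiples of 32 is returned with its shape unchanged.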


### 3.1 Find all wells and sites

We create a list of every well and site combination. `itertools.product` produces an iterator over every combination of its inputs; in this case, every site in every well.

In [5]:
wells_and_sites = list(
    itertools.product(
        WELLS,
        SITES if isinstance(SITES, list) else range(1, 13) 
    )
)
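
As a quick illustration of what `itertools.product` returns, here is a small standalone example. The second well name (`"C03"`) is hypothetical and only there to make the ordering visible.

```python
import itertools

wells = ["C02", "C03"]  # "C03" is a made-up well for illustration
sites = [1, 2]

# The rightmost input varies fastest, so all sites of a well are grouped together
pairs = list(itertools.product(wells, sites))
print(pairs)
# [('C02', 1), ('C02', 2), ('C03', 1), ('C03', 2)]
```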


In [None]:
random.seed(1)
random.shuffle(wells_and_sites)

split = int(len(wells_and_sites) * 0.85)

training_set = wells_and_sites[:split]
validation_set = wells_and_sites[split:]

print("Training on {0} images".format(len(training_set)))
print("Validating on {0} images".format(len(validation_set)))
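
With the default constants (one well, four sites), the 85 % split works out to three training images and one validation image. The sketch below repeats the arithmetic on a stand-in list; the variable names are chosen for illustration.

```python
import random

items = [("C02", site) for site in [1, 2, 3, 4]]

# Seeding makes the shuffle, and therefore the split, reproducible
random.seed(1)
random.shuffle(items)

split = int(len(items) * 0.85)  # int(3.4) == 3
training, validation = items[:split], items[split:]
print(len(training), len(validation))  # 3 1
```

Because `int` truncates, small datasets give the validation set slightly more than 15 % of the images.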

### 3.2 The root feature

We use DeepTrack 2 to define the data loader pipeline. The pipeline is a sequence of `features`, which perform computations. They are controlled by `properties`, which we pass when we create the feature.

As an example, `root` is a `DummyFeature`, which does not perform any computations and is instead just a container of properties. It is important to note that we can pass any argument of any name to a feature. If the argument is not used by the feature, we refer to it as a dummy property.

It takes the following arguments:

* `well_site_tuple` is a dummy property that cycles through the well-site combinations of the training set or the validation set, depending on the `validation` argument
* `well` is a dummy property that extracts the well from the `well_site_tuple`
* `site` is a dummy property that extracts the site from the `well_site_tuple`

Note that `well` and `site` are functions that take `well_site_tuple` as an argument. These are dependent properties, and DeepTrack automatically ensures that they receive the correct input.

In [6]:
training_iterator = itertools.cycle(training_set)
validation_iterator = itertools.cycle(validation_set)

def get_next_well_and_site(validation):
    if validation:
        return next(validation_iterator)
    else:
        return next(training_iterator)

# Accepts a tuple of form (well, site), and returns the well
def get_well_from_tuple(well_site_tuple):
    return well_site_tuple[0]

# Accepts a tuple of form (well, site), and returns the site as
# a string zero-padded to length 3 (e.g. 1 -> "001").
def get_site_from_tuple(well_site_tuple):
    site_string = "00" + str(well_site_tuple[1])
    return site_string[-3:]



root = dt.DummyFeature(
    well_site_tuple=get_next_well_and_site,           # On each update, root will grab the next value from this iterator
    well=get_well_from_tuple,                         # Grabs the well from the well_site_tuple
    site=get_site_from_tuple,                         # Grabs and formats the site from the well_site_tuple
)
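
One way to picture dependent properties is as functions whose argument names refer to other properties. The sketch below mimics that idea in plain Python. It is not how DeepTrack is implemented, and the `resolve` helper is hypothetical.

```python
import inspect
import itertools

def resolve(properties):
    """Evaluate properties in order, feeding each function the values it names."""
    values = {}
    for name, prop in properties.items():
        if callable(prop):
            # The parameter names tell us which earlier properties to pass in
            wanted = inspect.signature(prop).parameters
            values[name] = prop(**{key: values[key] for key in wanted})
        else:
            values[name] = prop
    return values

pairs = itertools.cycle([("C02", 1), ("C02", 2)])

properties = {
    "well_site_tuple": lambda: next(pairs),
    "well": lambda well_site_tuple: well_site_tuple[0],
    "site": lambda well_site_tuple: str(well_site_tuple[1]).zfill(3),
}

first = resolve(properties)
second = resolve(properties)
print(first["well"], first["site"])    # C02 001
print(second["well"], second["site"])  # C02 002
```

Each call to `resolve` draws one new tuple, and `well` and `site` are both derived from that same tuple, so they always stay consistent with each other.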

### 3.3 The brightfield loader

We use `deeptrack.LoadImage` to load and concatenate a brightfield stack.

It takes the following arguments:

* `**root.properties` means that we grab the properties of `root` (most importantly `well` and `site`). Other properties of `LoadImage` can now depend on these.
* `file_names` is a dummy property, which takes the current well and site as input, and creates a list of file names that we want to load.
* `path` is a property used by `LoadImage` to determine which files to load. We calculate it by taking `file_names` as input and returning a list of paths using `os.path.join`.

Since `path` is a list, `LoadImage` stacks the images along the last dimension, creating a (width, height, 7) shaped volume.

In [7]:
brightfield_loader = dt.LoadImage(
    **root.properties,
    file_names=lambda well, site: [file_name_struct.format(well, site, 4, z) for z in range(1, 8)],
    path=lambda file_names: [os.path.join(PATH_TO_MAGNIFICATION, file_name) for file_name in file_names],
)

### 3.4 The fluorescence loader

We use `deeptrack.LoadImage` to load and concatenate a fluorescence stack.

It takes the following arguments:

* `**root.properties` means that we grab the properties of `root` (most importantly `well` and `site`). Other properties of `LoadImage` can now depend on these.
* `file_names` is a dummy property, which takes the current well and site as input, and creates a list of file names that we want to load.
* `path` is a property used by `LoadImage` to determine which files to load. We calculate it by taking `file_names` as input and returning a list of paths using `os.path.join`.

Since `path` is a list, `LoadImage` stacks the images along the last dimension, creating a (width, height, 3) shaped volume.

In [None]:
fluorescence_loader = dt.LoadImage(
    **root.properties,
    file_names=lambda well, site: [file_name_struct.format(well, site, action, 1) for action in range(1, 4)],
    path=lambda file_names: [os.path.join(PATH_TO_MAGNIFICATION, file_name) for file_name in file_names],
)
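
To see which files the two loaders request, the file name template can be expanded by hand. This standalone snippet copies `file_name_struct` from the constants above and uses well `"C02"` and site `"001"` as example inputs.

```python
file_name_struct = "AssayPlate_Greiner_#655090_{0}_T0001F{1}L01A0{2}Z0{3}C0{2}.tif"

# Brightfield: channel 4, z-slices 1 to 7
brightfield_names = [file_name_struct.format("C02", "001", 4, z) for z in range(1, 8)]
# Fluorescence: channels 1 to 3, z-slice 1
fluorescence_names = [file_name_struct.format("C02", "001", c, 1) for c in range(1, 4)]

print(len(brightfield_names), len(fluorescence_names))  # 7 3
print(brightfield_names[0])
# AssayPlate_Greiner_#655090_C02_T0001F001L01A04Z01C04.tif
print(fluorescence_names[-1])
# AssayPlate_Greiner_#655090_C02_T0001F001L01A03Z01C03.tif
```

The third format argument fills both the `A` and the `C` field of the file name, which is why the two numbers always match.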

### 3.5 Offset adjustment

We correct the offset with an affine transformation. The offset is parametrized as a function of the magnification and the site.

The properties are set as follows:

* `translate` sets how far the image is translated, in pixels. It is a tuple representing the (x, y) shift, calculated as a function of the site's angle, with site 1 at angle 0.
* `angle` is a dummy property that calculates the angle of the site in radians.

In [9]:
# Coefficients of the regression
Ax = +2.3054
Bx = -0.0315
Ay = -0.1352
By = -2.3049
x = -0.8363
y = +0.8081
scale = 0.99975

correct_offset = dt.Affine(
    translate=lambda angle: (
        (np.cos(angle) * Bx + np.sin(angle) * Ax + x) * -1, # Offset in x
        (np.cos(angle) * By + np.sin(angle) * Ay + y) * -1, # Offset in y
    ),
    angle = lambda site: (int(site) - 1) * np.pi / 6,
    **root.properties,
)
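
To get a feel for the size of the correction, the regression can be evaluated in plain NumPy. This sketch reuses the coefficients from the cell above; the function name `site_offset` and the names `x0` and `y0` (standing in for the constants `x` and `y`) are introduced here for illustration only.

```python
import numpy as np

Ax, Bx = +2.3054, -0.0315
Ay, By = -0.1352, -2.3049
x0, y0 = -0.8363, +0.8081

def site_offset(site):
    """Return the (x, y) translation in pixels for a site given as "001", "002", ..."""
    # Each site advances the angle by pi / 6, so 12 sites span a full turn
    angle = (int(site) - 1) * np.pi / 6
    dx = -(np.cos(angle) * Bx + np.sin(angle) * Ax + x0)
    dy = -(np.cos(angle) * By + np.sin(angle) * Ay + y0)
    return dx, dy

for site in ("001", "004", "007"):
    dx, dy = site_offset(site)
    print(site, round(dx, 4), round(dy, 4))
# site 001 gives approximately ( 0.8678,  1.4968)
# site 004 gives approximately (-1.4691, -0.6729)
# site 007 gives approximately ( 0.8048, -3.1130)
```

The shifts are a few pixels at most, which is consistent with a small misalignment between the two channels.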

### 3.6 Augmentations

We use three kinds of augmentations: mirroring, affine transformations, and elastic distortions. A final crop cuts the augmented image down to 512 x 512 pixels.

In [5]:
flip = dt.FlipLR()

affine = dt.Affine(
    rotate=lambda: np.random.rand() * 2 * np.pi,
    scale=lambda: np.random.rand() * 0.1 + 0.95,
    shear=lambda: np.random.rand() * 0.05 - 0.025
)

distortion = dt.ElasticTransformation(
    alpha=lambda: np.random.rand() * 80,
    sigma=lambda: 7
)

corner = 512 * (np.sqrt(2) - 1) / 2
cropping  = dt.Crop(
    crop=(512, 512, None),
    corner=(corner, corner, 0)
)
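
The value of `corner` can be checked with a little arithmetic. The sketch below assumes that the augmentations are applied to a square patch of side `int(512 * sqrt(2))`, which is the `padded_crop_size` used when the pipeline is assembled. The geometric reading in the comments is an interpretation of these numbers, not something stated by the notebook.

```python
import numpy as np

final_size = 512

# A square of side 512 * sqrt(2) contains the circle through the corners of
# the central 512 x 512 square, so (up to rounding to whole pixels) the
# central square stays inside the patch under any rotation.
padded_crop_size = int(final_size * np.sqrt(2))
corner = final_size * (np.sqrt(2) - 1) / 2

print(padded_crop_size)  # 724
print(round(corner, 2))  # 106.04

# The final crop is centered: the margin is the same on both sides
print(round(2 * corner + final_size, 2))  # 724.08
```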

### 3.7 Creating the pipeline

We use the `+` operator to chain the features, which defines the execution order.

In DeepTrack, this means that the output of the feature on the left is passed as the input to the feature on the right.

In other words, the stack loaded by `brightfield_loader` is passed to `correct_offset`, which corrects the offset.

In [10]:
corrected_brightfield = brightfield_loader + correct_offset

data_pair = dt.Combine([corrected_brightfield, fluorescence_loader])

validation_set



In [6]:
padded_crop_size = int(512 * np.sqrt(2))

cropped_data = dt.Crop(
    data_pair,
    crop=(padded_crop_size, padded_crop_size, None),
    updates_per_reload=16,
    corner="random",
)

augmented_data = cropped_data + flip + affine + distortion + cropping

In [None]:
validation_data = data_pair + dt.PadToMultiplesOf(multiple=(32, 32, None))

In [None]:
dataset = dt.ConditionalSetFeature(
    on_true=validation_data,
    on_false=augmented_data,
    condition="validation"
)

## 4. Define generator

We use generators to interface DeepTrack features with Keras training. We have defined some specialized ones, which speed up training.

Here we use `ContinuousGenerator`.

In [11]:
generator = dt.generators.ContinuousGenerator(
    dataset,
    batch_function=lambda image: image[0],
    label_function=lambda image: image[1],
    batch_size=8,
    min_data_size=500,
    max_data_size=2000
)

Loading 4 files


## 5. Define model

In [12]:
GAN_generator = apido.generator(GENERATOR_BREADTH, GENERATOR_DEPTH)
GAN_discriminator = apido.discriminator(DISCRIMINATOR_DEPTH)

GAN = dt.models.cgan(
    generator=GAN_generator,
    discriminator=GAN_discriminator,
    discriminator_loss="mse",
    assemble_loss=["mse", "mae"],
    assemble_loss_weights=[1 - MAE_LOSS_WEIGHT, MAE_LOSS_WEIGHT]
)

GAN.compile(metrics=apido.metrics("60x"))


## 6. Train model

In [None]:
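# Resolve one (brightfield, fluorescence) pair per validation image, then
# regroup the pairs into one stack of inputs and one stack of targets.
# Assumes all validation images share the same shape after padding.
validation_pairs = [
    dataset.update(validation=True).resolve()
    for _ in range(len(validation_set))
]
validation_data = tuple(np.array(stack) for stack in zip(*validation_pairs))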

with generator:
    h = GAN.fit(
        generator, 
        epochs=500,
        steps_per_epoch=32,
        validation_data=validation_data,
        validation_batch_size=4
    )

## 7. Save results

Iterate over the stains and save the results to disk. In DeepTrack, the properties used to create an image using `resolve()` are stored in a field called `properties`, and can easily be retrieved using the utility method `get_property`.

Here, we use this to extract the well and site of the image, which are needed to name the file correctly. Moreover, some features save additional values. One such case is `undo_padding`, which is saved by all padding features. This is a tuple of slices that restores a NumPy array to its pre-padded size.
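
The idea of undoing padding with a tuple of slices can be shown with plain NumPy. In this sketch the `undo_padding` tuple is built by hand as a stand-in for the value that a padding feature would store; the array sizes and pad widths are arbitrary.

```python
import numpy as np

original = np.arange(5 * 3).reshape(5, 3)

# Pad 2 rows before and 1 after, 0 columns before and 4 after (zeros by default)
padded = np.pad(original, ((2, 1), (0, 4)))

# Hand-built stand-in for a stored `undo_padding` value
undo_padding = (slice(2, 2 + 5), slice(0, 3))

restored = padded[undo_padding]
print(padded.shape)                        # (8, 7)
print(np.array_equal(restored, original))  # True
```

Indexing with the tuple selects exactly the region that held the original data, so no bookkeeping about pad widths is needed when saving.
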

In [15]:
apido.save_training_results(...)

Saved image to: ./local_results/./local_data\60x images\AssayPlate_Greiner_#655090_C02_T0001F001L01A01Z01C01.tif
Saved image to: ./local_results/./local_data\60x images\AssayPlate_Greiner_#655090_C02_T0001F001L01A02Z01C02.tif
Saved image to: ./local_results/./local_data\60x images\AssayPlate_Greiner_#655090_C02_T0001F001L01A03Z01C03.tif
Saved image to: ./local_results/./local_data\60x images\AssayPlate_Greiner_#655090_C02_T0001F002L01A01Z01C01.tif
Saved image to: ./local_results/./local_data\60x images\AssayPlate_Greiner_#655090_C02_T0001F002L01A02Z01C02.tif
Saved image to: ./local_results/./local_data\60x images\AssayPlate_Greiner_#655090_C02_T0001F002L01A03Z01C03.tif
Saved image to: ./local_results/./local_data\60x images\AssayPlate_Greiner_#655090_C02_T0001F003L01A01Z01C01.tif
Saved image to: ./local_results/./local_data\60x images\AssayPlate_Greiner_#655090_C02_T0001F003L01A02Z01C02.tif
Saved image to: ./local_results/./local_data\60x images\AssayPlate_Greiner_#655090_C02_T0001F003