# Example 1A - Virtual staining of brightfield images (60x)

Example code to virtually stain brightfield images captured with the 60x magnification objective, obtaining the corresponding images for nuclei, lipids and cytoplasm.

This example can be used to virtually stain other brightfield images by changing the user-defined loading and saving paths in section 1.1 below.

version 1.0 <br />
15 November 2020 <br />
Benjamin Midtvedt, Jesús Pineda Castro, Saga Helgadottir, Daniel Midtvedt & Giovanni Volpe <br />
Soft Matter Lab @ GU <br />
http://www.softmatterlab.org

## 0. Imports
 
Import all necessary packages. These include standard Python packages as well as the core of DeepTrack 2.0 (`deeptrack`) and some specialized classes for this virtual staining (`apido`).

In [1]:
import os
import re
import glob
import itertools
from timeit import default_timer as timer

import numpy as np
from PIL import Image

# DeepTrack 2.0 code
import apido
from apido import deeptrack as dt

## 1. Define input and output

Set constants to determine the input and output images.

### 1.1 User-defined constants to load test images and save the virtually stained images

Constants defined by the user:

* `DATASET_PATH`: Input path (not including the magnification folder)

* `OUTPUT_PATH`: Output path (not including the magnification folder)


In [2]:
DATASET_PATH = "./validation_data/60x images/" 
OUTPUT_PATH = "./validation_results/60x images/"

### 1.2 Inferred constants

Constants automatically inferred from the user input.

In [3]:
MAGNIFICATION = "60x"
file_name_struct = "AssayPlate_Greiner_#655090_{0}_T0001F{1}L01A0{2}Z0{3}C0{2}.tif"
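To see how the template fields are filled in, the sketch below formats two names by hand. Judging from how the notebook uses the template, `{0}` is the well, `{1}` the site, `{2}` the action/channel index (used for both `A0` and `C0`), and `{3}` the z-slice. The well and site values are illustrative, taken from the saved-file log at the end of this notebook.

```python
# Template copied from the cell above
file_name_struct = "AssayPlate_Greiner_#655090_{0}_T0001F{1}L01A0{2}Z0{3}C0{2}.tif"

# Illustrative values: well B03, site 001, action/channel 1, z-slice 1
stain_name = file_name_struct.format("B03", "001", 1, 1)
print(stain_name)  # AssayPlate_Greiner_#655090_B03_T0001F001L01A01Z01C01.tif

# Index 4 is the value the brightfield loader uses; here with the third z-slice
brightfield_name = file_name_struct.format("B03", "001", 4, 3)
print(brightfield_name)  # AssayPlate_Greiner_#655090_B03_T0001F001L01A04Z03C04.tif
```

Because the template hard-codes a leading `0` before `{2}` and `{3}`, it only yields well-formed names for single-digit channel and z-slice indices, which holds for the values used in this notebook (1 to 4 and 1 to 7).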

Infer and normalize paths.

In [4]:
DATASET_PATH = os.path.normpath(DATASET_PATH)
OUTPUT_PATH = os.path.normpath(OUTPUT_PATH)

PATH_TO_MODEL = os.path.normpath(
    os.path.abspath(
        os.path.join("models", MAGNIFICATION)
    )
)

In [5]:
print("Loading images from: \t\t", DATASET_PATH)
print("Saving results to: \t\t", OUTPUT_PATH)
print("Loading pretrained model from: \t", PATH_TO_MODEL)

Loading images from: 		 validation_data/60x images
Saving results to: 		 validation_results/60x images
Loading pretrained model from: 	 /workspace/Team-Soft-Matter-Lab-GU/models/60x
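The printed paths differ from the user input because `os.path.normpath` drops the leading `./` and the trailing separator. A minimal illustration follows. It uses `posixpath` (the POSIX flavor of `os.path`) so that the result does not depend on the operating system; on Windows, `os.path.normpath` would produce backslashes instead.

```python
import posixpath

# The leading "./" and the trailing separator are removed
print(posixpath.normpath("./validation_data/60x images/"))  # validation_data/60x images

# "." and ".." components are collapsed as well
print(posixpath.normpath("./models/60x/../60x"))  # models/60x
```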


## 2. Load model

Load the pretrained virtual stainer from the local path.

Note that some warnings about overwriting `groups` are expected here; they are not a problem.

In [6]:
virtual_stainer = apido.load_model(PATH_TO_MODEL)
virtual_stainer.summary()



Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, None, None,  0                                            
__________________________________________________________________________________________________
lambda (Lambda)                 (None, None, None, 7 0           input_1[0][0]                    
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, None, None, 3 2048        lambda[0][0]                     
__________________________________________________________________________________________________
instance_normalization (Instanc (None, None, None, 3 64          conv2d[0][0]                     
_______________________________________________________________________________________

## 3. Load input data

We define a data pipeline to load brightfield images from storage. It uses DeepTrack 2.0 and has the following structure:

1. Load each z-slice of a well-site combination and concatenate them.
2. Pad the volume such that the first two dimensions are multiples of 32 (required by the model).
3. Correct for the misalignment between the fluorescence and brightfield channels, using a pre-calculated parametrization of the offset as a function of the magnification and the site.

### 3.1 Find all wells and sites

We create a list of all well-site combinations.

In [7]:
file_list = glob.glob(os.path.join(DATASET_PATH, "*C01.tif"))

SITES = [re.findall("F([0-9]{3})", f)[-1] for f in file_list]
WELLS = [re.findall("_([A-Z][0-9]{2})_", f)[-1] for f in file_list]

wells_and_sites = list(zip(WELLS, SITES))
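As a quick check of the two regular expressions, they can be applied to one file name from the saved-file log at the end of the notebook (used here purely as sample input). The code keeps the last match (`[-1]`), so if another part of the path also matched a pattern, the match inside the file name itself would still be used. The sketch writes the patterns as raw strings (`r"..."`), the usual convention for regular expressions; these patterns contain no backslashes, so the result is the same.

```python
import re

sample = "validation_data/60x images/AssayPlate_Greiner_#655090_B03_T0001F001L01A01Z01C01.tif"

site = re.findall(r"F([0-9]{3})", sample)[-1]        # three digits following an "F"
well = re.findall(r"_([A-Z][0-9]{2})_", sample)[-1]  # one letter and two digits between underscores

print(well, site)  # B03 001
```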

### 3.2 The root feature

We use DeepTrack 2.0 to define the data loader pipeline. The pipeline is a sequence of `features`, which perform computations controlled by `properties`, which are defined when the features are created. (Note that we can pass any property, with any name and value, to a feature; if a property is not used by the feature, we refer to it as a dummy property.)

The feature `root` is a `DummyFeature`, which is just a container of dummy properties and does not perform any computations.
It takes the following arguments:

* `well_site_tuple` is a dummy property that cycles through the well-site combinations in `wells_and_sites`
* `well` is a dummy property that extracts the well from the `well_site_tuple`
* `site` is a dummy property that extracts the site from the `well_site_tuple`

Note that `well` and `site` are functions that take `well_site_tuple` as an argument. These are dependent properties, and DeepTrack 2.0 automatically ensures that they receive the correct input.

In [8]:
root = dt.DummyFeature(
    well_site_tuple=itertools.cycle(wells_and_sites), # On each update, root will grab the next value from this iterator
    well=lambda well_site_tuple: well_site_tuple[0],  # Grabs the well from the well_site_tuple
    site=lambda well_site_tuple: well_site_tuple[1],  # Grabs the site from the well_site_tuple
)
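Conceptually, each update of `root` draws the next tuple from the cycling iterator and recomputes the two dependent properties from it. The plain-Python sketch below mimics that behavior with a hypothetical `update` function. It is an analogy only, not DeepTrack's internal mechanism, and the well-site values are illustrative.

```python
import itertools

# Illustrative well-site combinations (not read from disk)
wells_and_sites = [("B03", "001"), ("B03", "002"), ("B03", "010")]
well_site_iterator = itertools.cycle(wells_and_sites)

# Dependent properties: functions of another property, mirroring the lambdas passed to `root`
get_well = lambda well_site_tuple: well_site_tuple[0]
get_site = lambda well_site_tuple: well_site_tuple[1]


def update():
    """Hypothetical stand-in: draw the next tuple and recompute the properties that depend on it."""
    well_site_tuple = next(well_site_iterator)
    return {
        "well_site_tuple": well_site_tuple,
        "well": get_well(well_site_tuple),
        "site": get_site(well_site_tuple),
    }


# len(wells_and_sites) updates visit every combination exactly once
visited = [update() for _ in range(len(wells_and_sites))]
print([(p["well"], p["site"]) for p in visited])  # [('B03', '001'), ('B03', '002'), ('B03', '010')]
```

Since the iterator cycles, calling `update()` exactly `len(wells_and_sites)` times, as Section 4 does with `num_files`, covers each combination once; one more call wraps around to the first combination.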

### 3.3 The brightfield image loader

We use `deeptrack.LoadImage` to load and concatenate a brightfield stack. It takes the following arguments:

* `**root.properties` means that we pass on the properties of `root` (most importantly `well` and `site`). The other properties of `LoadImage` can now depend on these.
* `file_names` is a dummy property, which takes the current well and site as input, and creates a list of file names that we want to load.
* `path` is a property used by `LoadImage` to determine which files to load. We calculate it by taking `file_names` as input and returning a list of paths using `os.path.join`.

Since `path` is a list, `LoadImage` stacks the images along the last dimension, creating a volume of shape (height, width, 7), with one channel per z-slice.

In [9]:
brightfield_loader = dt.LoadImage(
    **root.properties,
    file_names=lambda well, site: [file_name_struct.format(well, site, 4, z) for z in range(1, 8)],
    path=lambda file_names: [os.path.join(DATASET_PATH, file_name) for file_name in file_names],
)
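Evaluated by hand for one illustrative well-site combination, the `file_names` and `path` lambdas produce seven brightfield file names (action/channel index 4, z-slices 1 to 7) and their full paths. The functions below are hypothetical plain-Python stand-ins for those lambdas, with `DATASET_PATH` set to the normalized value printed earlier.

```python
import os

file_name_struct = "AssayPlate_Greiner_#655090_{0}_T0001F{1}L01A0{2}Z0{3}C0{2}.tif"
DATASET_PATH = "validation_data/60x images"


def brightfield_file_names(well, site):
    # Same expression as the `file_names` lambda: index 4, z-slices 1..7
    return [file_name_struct.format(well, site, 4, z) for z in range(1, 8)]


def brightfield_paths(file_names):
    # Same expression as the `path` lambda
    return [os.path.join(DATASET_PATH, file_name) for file_name in file_names]


names = brightfield_file_names("B03", "001")
paths = brightfield_paths(names)
print(names[0])   # AssayPlate_Greiner_#655090_B03_T0001F001L01A04Z01C04.tif
print(names[-1])  # AssayPlate_Greiner_#655090_B03_T0001F001L01A04Z07C04.tif
```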

### 3.4 Padding

The model requires the two primary dimensions of the input to be multiples of 32. We ensure this using `deeptrack.PadToMultiplesOf`. This feature also adds a property which allows us to restore the model prediction to the original shape. It takes the following argument:

* `multiple` is a tuple of multiples, one per dimension. In our case, we set the first two dimensions to 32 and the third dimension to `None` (meaning that it is not padded).

In [10]:
ensure_padded = dt.PadToMultiplesOf(multiple=(32, 32, None))
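The padding arithmetic can be sketched in plain NumPy: the smallest amount of padding that brings a size up to a multiple of `m` is `(-size) % m`, e.g. 24 extra rows for a height of 1000. This is an illustration under assumptions, not DeepTrack's implementation: the hypothetical helper `pad_to_multiples_of` zero-pads at the end of each axis (DeepTrack may distribute the padding differently) and returns a tuple of slices analogous to the `undo_padding` property used in Section 5. The input shape is made up.

```python
import numpy as np


def pad_to_multiples_of(image, multiple=(32, 32, None)):
    """Zero-pad `image` at the end of each axis whose entry in `multiple` is not None,
    so that its size becomes a multiple of that entry.

    Returns the padded array and a tuple of slices that crops it back
    (a hypothetical analogue of the `undo_padding` property).
    """
    pad_width = []
    undo_padding = []
    for size, m in zip(image.shape, multiple):
        extra = 0 if m is None else (-size) % m  # smallest padding reaching a multiple of m
        pad_width.append((0, extra))
        undo_padding.append(slice(0, size))
    return np.pad(image, pad_width), tuple(undo_padding)


# Hypothetical 7-slice stack whose height and width are not multiples of 32
stack = np.random.default_rng(0).random((1000, 1010, 7))
padded, undo_padding = pad_to_multiples_of(stack)
print(padded.shape)  # (1024, 1024, 7)

restored = padded[undo_padding]
print(np.array_equal(restored, stack))  # True
```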

### 3.5 Offset adjustment

We correct the offset with an affine transformation. The offset is parametrized as a function of the magnification and the site, as described in the report.

The properties are set as follows:
* `translate` sets how much we translate the image in pixels. It is a tuple representing the (x, y) shift. We calculate it as a function of the angular position of the site within the well, with site 1 at angle 0.
* `scale` sets how much we rescale the image.
* `angle` is a dummy property that calculates the angle of the site in radians.

In [11]:
# Affine transformation parameters (precalculated, see report)
Ax = +2.3054
Bx = -0.0315
Ay = -0.1352
By = -2.3049
x = -0.8363
y = +0.8081
scale = 0.99975

correct_offset = dt.Affine(
    translate=lambda angle: (
        (np.cos(angle) * Bx + np.sin(angle) * Ax + x),  # Offset in x
        (np.cos(angle) * By + np.sin(angle) * Ay + y),  # Offset in y
    ),
    scale=scale,
    angle=lambda site: (int(site) - 1) * np.pi / 6,  # Angular position of the site, site 1 at angle 0
    **root.properties,
)

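To make the `translate` and `angle` lambdas concrete, the sketch below evaluates the same formulas in plain Python for two sites, reusing the parameters from the cell above (renamed `x0` and `y0` here only for readability). Consecutive sites are π/6 (30°) apart, starting from site 1 at angle 0. The helper names are hypothetical, and the sketch only reproduces the arithmetic; it does not transform an image.

```python
import math

# Parameters copied from the offset-correction cell
Ax, Bx = 2.3054, -0.0315
Ay, By = -0.1352, -2.3049
x0, y0 = -0.8363, 0.8081


def site_angle(site):
    """Angular position of a site in radians; site "001" sits at angle 0."""
    return (int(site) - 1) * math.pi / 6


def translation(site):
    """(x, y) shift in pixels, using the same formula as the `translate` lambda."""
    angle = site_angle(site)
    return (
        math.cos(angle) * Bx + math.sin(angle) * Ax + x0,
        math.cos(angle) * By + math.sin(angle) * Ay + y0,
    )


print(translation("001"))  # approximately (-0.8678, -1.4968)
print(translation("004"))  # approximately (1.4691, 0.6729)
```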

### 3.6 Create the pipeline 

We use the `+` operator to chain the features, defining the execution order. In DeepTrack 2.0, this means that the output of the feature on the left is passed as the input to the feature on the right. In other words, the stack loaded by `brightfield_loader` is passed to `ensure_padded`, the output of which is offset-corrected by `correct_offset`. 

In [12]:
brightfield_stack_pipeline = brightfield_loader + ensure_padded + correct_offset
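The left-to-right semantics of `+` can be illustrated with a toy class that overloads `__add__`. This is a hypothetical sketch of the idea only, not how DeepTrack 2.0 implements features (which also track and update properties); `Step` and the three stand-in steps are made up for the example.

```python
class Step:
    """Hypothetical stand-in for a pipeline feature: wraps a function of one input."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, value):
        return self.fn(value)

    def __add__(self, other):
        # `self + other` runs `self` first and feeds its output into `other`
        return Step(lambda value: other(self(value)))


load = Step(lambda n: list(range(n)))              # stand-in for loading
pad = Step(lambda xs: xs + [0] * (-len(xs) % 4))   # stand-in for padding to a multiple of 4
shift = Step(lambda xs: [v + 1 for v in xs])       # stand-in for the offset correction

pipeline = load + pad + shift
print(pipeline(5))  # [1, 2, 3, 4, 5, 1, 1, 1]
```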

## 4. Calculate the virtually stained images from the brightfield images

To obtain the actual brightfield images needed for further processing, we need to resolve the pipeline `brightfield_stack_pipeline`.

Each time we call `update()`, we update each property of all the features in the pipeline, ensuring that each dependent property is updated in the right order. Therefore, at each update, we select a new well-site combination, from which we obtain all other properties (the path of the images, the angle of the site, etc.).

The subsequent call to `resolve()` executes each feature in the pipeline in order, producing a brightfield stack.

We repeat this once per image we want to load, for a total of `num_files` images.

In [13]:
num_files = len(wells_and_sites)

print("Loading {0} files".format(num_files))
list_of_brightfield_images = [
    brightfield_stack_pipeline.update().resolve() for _ in range(num_files)
]

Loading 15 files


Time the model's execution. Set `TIME_EXECUTION` to `False` to skip this step!

In [14]:
TIME_EXECUTION = True
ITERATIONS = 15

model_input = np.array(list_of_brightfield_images)

if TIME_EXECUTION:
    
    timings = []
    
    for iteration in range(ITERATIONS):
        start = timer()
        virtual_stainer.predict(model_input, batch_size=4)
        end = timer()
        timings.append(end - start)

        print("Finished iteration", iteration)
    
    num_predictions = num_files * (ITERATIONS - 1)
    timings = timings[1:] # Skipping first iteration to ignore cold-start

    
    print(
        "Made {0} predictions in {1} seconds. Median {2} seconds per image, minimum of {3} seconds, not bad!"
        .format(num_predictions, np.sum(timings), np.median(timings) / num_files, np.min(timings) / num_files)
    )

Finished iteration 0
Finished iteration 1
Finished iteration 2
Finished iteration 3
Finished iteration 4
Finished iteration 5
Finished iteration 6
Finished iteration 7
Finished iteration 8
Finished iteration 9
Finished iteration 10
Finished iteration 11
Finished iteration 12
Finished iteration 13
Finished iteration 14
Made 210 predictions in 58.90760607365519 seconds. Median 0.2724334174146255 seconds per image, minimum of 0.26199914862712226 seconds, not bad!
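The timing cell discards the first iteration as a cold start and divides each batch time by `num_files` to obtain per-image figures. With `num_files = 15` and `ITERATIONS = 15`, this gives 15 × 14 = 210 predictions, matching the printed count. The same logic can be factored into two small helpers; `benchmark` and `per_image_stats` are hypothetical names, and the toy workload stands in for the model call.

```python
import statistics
from timeit import default_timer as timer


def benchmark(fn, iterations=15, warmup=1):
    """Time `iterations` calls of `fn` and drop the first `warmup` timings (cold start)."""
    timings = []
    for _ in range(iterations):
        start = timer()
        fn()
        timings.append(timer() - start)
    return timings[warmup:]


def per_image_stats(timings, images_per_call):
    """Convert per-call timings into per-image figures, as in the print statement above."""
    return {
        "predictions": images_per_call * len(timings),
        "total_seconds": sum(timings),
        "median_per_image": statistics.median(timings) / images_per_call,
        "min_per_image": min(timings) / images_per_call,
    }


# Toy workload standing in for `virtual_stainer.predict(...)`
kept = benchmark(lambda: sum(range(10_000)), iterations=5)
print(len(kept))  # 4

print(per_image_stats([2.0, 4.0, 3.0], images_per_call=2))
# {'predictions': 6, 'total_seconds': 9.0, 'median_per_image': 1.5, 'min_per_image': 1.0}
```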


Next, we calculate the virtual stains. All we need to do is call `virtual_stainer.predict`; no further pre-processing is needed (the padding is undone when the results are saved).

In [15]:
stains = virtual_stainer.predict(model_input, batch_size=4)

## 5. Save results

Finally, we iterate over all the virtually stained images and save the results to disk. In DeepTrack 2.0, the properties used to create an image with `resolve()` are stored in a field called `properties` and can easily be retrieved using the utility method `get_property`. Here, we use this information to extract the well and site of each image, which are needed to name the output file correctly. Moreover, some features store additional properties. One such case is `undo_padding`, which is stored by all padding features: a tuple of slices that crops a NumPy array back to its original size before padding.
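In the saving cell, each channel is cast with `.astype(np.uint16)`, which truncates fractional values toward zero and does not explicitly handle values outside 0 to 65535. The notebook does not state the model's output range, so the following is only a defensive variant under that assumption: the hypothetical helper `to_uint16` rounds to the nearest integer and clips before casting. Note that rounding changes results slightly compared with plain truncation (12.6 becomes 13 rather than 12).

```python
import numpy as np


def to_uint16(layer):
    """Round to the nearest integer, clip to [0, 65535] and cast to uint16.

    Hypothetical defensive alternative to a bare `.astype(np.uint16)`.
    """
    max_value = np.iinfo(np.uint16).max  # 65535
    return np.clip(np.rint(layer), 0, max_value).astype(np.uint16)


layer = np.array([[-3.2, 12.6], [70000.0, 1234.4]])
print(to_uint16(layer).tolist())  # [[0, 13], [65535, 1234]]
```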

In [16]:
output_format = os.path.join(OUTPUT_PATH, file_name_struct)
os.makedirs(OUTPUT_PATH, exist_ok=True)

for brightfield, prediction in zip(list_of_brightfield_images, stains):
    
    well = brightfield.get_property("well")
    site = brightfield.get_property("site")
    
    # Undo the padding required by the model
    undo_padding = brightfield.get_property("undo_padding")
    prediction = prediction[undo_padding]
    
    # Save each of the three predicted stains as a separate 16-bit TIFF (A01/C01, A02/C02, A03/C03)
    for action in range(3):
        file_path = output_format.format(well, site, action + 1, 1)
        
        prediction_layer = Image.fromarray(prediction[..., action].astype(np.uint16))
        prediction_layer.save(file_path)

        print("Saved image to:", file_path)

Saved image to: validation_results/60x images/AssayPlate_Greiner_#655090_B03_T0001F001L01A01Z01C01.tif
Saved image to: validation_results/60x images/AssayPlate_Greiner_#655090_B03_T0001F001L01A02Z01C02.tif
Saved image to: validation_results/60x images/AssayPlate_Greiner_#655090_B03_T0001F001L01A03Z01C03.tif
Saved image to: validation_results/60x images/AssayPlate_Greiner_#655090_B03_T0001F002L01A01Z01C01.tif
Saved image to: validation_results/60x images/AssayPlate_Greiner_#655090_B03_T0001F002L01A02Z01C02.tif
Saved image to: validation_results/60x images/AssayPlate_Greiner_#655090_B03_T0001F002L01A03Z01C03.tif
Saved image to: validation_results/60x images/AssayPlate_Greiner_#655090_B03_T0001F010L01A01Z01C01.tif
Saved image to: validation_results/60x images/AssayPlate_Greiner_#655090_B03_T0001F010L01A02Z01C02.tif
Saved image to: validation_results/60x images/AssayPlate_Greiner_#655090_B03_T0001F010L01A03Z01C03.tif
Saved image to: validation_results/60x images/AssayPlate_Greiner_#655090_