# Example 1B - Training of virtual staining of brightfield images (60x)

Example code to train a neural network to virtually stain brightfield images captured with a 60x objective, producing the corresponding images of nuclei, lipids, and cytoplasm.

version 1.0 <br />
15 November 2020 <br />
Benjamin Midtvedt, Jesús Pineda Castro, Saga Helgadottir, Daniel Midtvedt & Giovanni Volpe <br />
Soft Matter Lab @ GU <br />
http://www.softmatterlab.org

## 0. Imports
 
Import all necessary packages. These include standard Python packages as well as the core of DeepTrack 2.0 (`deeptrack`) and some specialized classes for this virtual staining (`apido`).

In [1]:
import os
import re
import glob
import random
import itertools
from timeit import default_timer as timer

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from scipy import optimize
from tensorflow.keras.optimizers import Adam

# DeepTrack 2.0 code
import apido
from apido import deeptrack as dt

## 1. Define input and output

Set the constants that determine the input and output images.

### 1.1 User-defined constants for loading data and saving model

Constants defined by the user:

* `DATASET_PATH`: Input path (not including the magnification folder)

* `OUTPUT_PATH`: Output path (not including the magnification folder)

* `WELLS`: Name of the wells on which to predict

* `SITES`: "all" or list of integers

In [3]:
MAGNIFICATION = "60x"
DATASET_PATH = "../astra_data_readonly/"
OUTPUT_PATH = "models"  # Assumed value, consistent with the output printed below

### 1.2 Inferred constants

Constants inferred from the user input

In [4]:
MAGNIFICATION = "60x"
file_name_struct = "AssayPlate_Greiner_#655090_{0}_T0001F{1}L01A0{2}Z0{3}C0{2}.tif"

PATH_TO_OUTPUT = os.path.normpath(OUTPUT_PATH)
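
To illustrate how the file name template is filled in, the following sketch formats it with a made-up well (`B03`) and site (`001`). The values 4 and 1 are the ones passed by the brightfield loader in this notebook; the well and site are assumptions chosen for illustration only.

```python
# Same template as in the cell above; the well "B03" and site "001" are made-up values.
file_name_struct = "AssayPlate_Greiner_#655090_{0}_T0001F{1}L01A0{2}Z0{3}C0{2}.tif"

# {0} is the well, {1} the site, {2} fills both the A0 and C0 fields, {3} the Z0 field.
example_name = file_name_struct.format("B03", "001", 4, 1)
print(example_name)
# AssayPlate_Greiner_#655090_B03_T0001F001L01A04Z01C04.tif
```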

Infer full path to dataset

In [5]:
_glob_struct = os.path.join(DATASET_PATH, MAGNIFICATION + "*/")
_glob_results = glob.glob(_glob_struct)

if len(_glob_results) == 0:
    raise ValueError("No path found matching glob {0}".format(_glob_struct))
elif len(_glob_results) > 1:
    from warnings import warn
    warn("Multiple paths found! Using {0}".format(_glob_results[0]))

PATH_TO_MAGNIFICATION = os.path.normpath(_glob_results[0])
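
As a small self-contained illustration of the glob pattern used above, the sketch below creates two temporary folders and shows that the pattern `"60x" + "*/"` matches only the directory whose name starts with `60x`. The folder names are assumptions made for this example.

```python
import glob
import os
import tempfile

with tempfile.TemporaryDirectory() as dataset_path:
    # Made-up folder names, for illustration only
    os.makedirs(os.path.join(dataset_path, "60x images"))
    os.makedirs(os.path.join(dataset_path, "20x images"))

    # The trailing "/" restricts the match to directories
    matches = glob.glob(os.path.join(dataset_path, "60x" + "*/"))
    # normpath removes the trailing separator before taking the folder name
    folder_name = os.path.basename(os.path.normpath(matches[0]))

print(len(matches), folder_name)  # 1 60x images
```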

In [6]:
print("Loading images from: \t", PATH_TO_MAGNIFICATION)
print("Saving results to: \t", PATH_TO_OUTPUT)

Loading images from: 	 test_data\60x images
Saving results to: 	 models


## 2. Load train data

We define a data pipeline for loading images from storage. The pipeline uses DeepTrack 2.0 and follows these steps:

1. Load each z-slice of a well-site combination and concatenate them.
2. Pad the volume such that the first two dimensions are multiples of 32 (required by the model).
3. Correct for misalignment of the fluorescence channel and the brightfield channel (by a pre-calculated parametrization of the offset as a function of magnification and the site).
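
The padding in step 2 can be sketched with plain NumPy. This is a minimal illustration, not the notebook's own implementation: the helper `pad_to_multiple` is hypothetical, and padding with zeros at the end of each axis is an assumption.

```python
import numpy as np

def pad_to_multiple(volume, multiple=32):
    """Zero-pad the first two axes up to the next multiple of `multiple`."""
    pad_rows = (-volume.shape[0]) % multiple
    pad_cols = (-volume.shape[1]) % multiple
    # Only the first two axes are padded; any remaining axes are left untouched
    pad_width = [(0, pad_rows), (0, pad_cols)] + [(0, 0)] * (volume.ndim - 2)
    return np.pad(volume, pad_width, mode="constant")

volume = np.ones((100, 70, 7))
padded = pad_to_multiple(volume)
print(padded.shape)  # (128, 96, 7)
```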

### 2.1 Find all wells and sites

We list every image of the first channel (`*C01.tif`) and extract the site and the well from each file name with regular expressions. Zipping the two lists gives one (well, site) pair per file.

In [7]:
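# The images are stored in the magnification folder, which is also where the loaders read from
file_list = glob.glob(os.path.join(PATH_TO_MAGNIFICATION, "*C01.tif"))

SITES = [re.findall(r"F([0-9]{3})", f)[-1] for f in file_list]
WELLS = [re.findall(r"_([A-Z][0-9]{2})_", f)[-1] for f in file_list]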

wells_and_sites = list(zip(WELLS, SITES))
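
To see what the two regular expressions extract, the sketch below applies them to a single made-up file name that follows the naming scheme of this notebook. The well `B03` and site `001` are assumptions chosen for illustration.

```python
import re

# Hypothetical file name following the naming scheme used in this notebook
example_file = "AssayPlate_Greiner_#655090_B03_T0001F001L01A01Z01C01.tif"

# "F" followed by three digits is the site; a letter and two digits between underscores is the well
site = re.findall(r"F([0-9]{3})", example_file)[-1]
well = re.findall(r"_([A-Z][0-9]{2})_", example_file)[-1]
print(well, site)  # B03 001
```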

### 2.2 The root feature

We use DeepTrack 2.0 to define the data loader pipeline. The pipeline is a sequence of `features`, which perform computations, controlled by `properties`, which are defined when creating the features. (Note that we can add any property with any name and value to a feature; if a property is not used by the feature, we refer to it as a dummy property.)

The feature `root` is a `DummyFeature`, which is just a container of dummy properties and does not perform any computations.
It takes the following arguments:

* `well_site_tuple` is a dummy property that cycles through the well-site combinations in `wells_and_sites`
* `well` is a dummy property that extracts the well from the `well_site_tuple`
* `site` is a dummy property that extracts the site from the `well_site_tuple`

Note that `well` and `site` are functions that take `well_site_tuple` as argument. These are dependent properties, and DeepTrack 2.0 will automatically ensure that they receive the correct input.

In [9]:
root = dt.DummyFeature(
    well_site_tuple=itertools.cycle(wells_and_sites), # On each update, root will grab the next value from this iterator
    well=lambda well_site_tuple: well_site_tuple[0],  # Grabs the well from the well_site_tuple
    site=lambda well_site_tuple: well_site_tuple[1],  # Grabs the site from the well_site_tuple
)
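
The idea of dependent properties can be illustrated in plain Python. The sketch below is not how DeepTrack 2.0 implements it; the helper `resolve_properties` is hypothetical and only mimics the behavior described above: iterators yield their next value on each update, and functions receive the already-resolved properties that match their argument names.

```python
import inspect
import itertools

def resolve_properties(properties):
    """Resolve properties in order; callables receive earlier values by argument name."""
    resolved = {}
    for name, value in properties.items():
        if callable(value):
            arg_names = inspect.signature(value).parameters
            resolved[name] = value(**{arg: resolved[arg] for arg in arg_names})
        elif hasattr(value, "__next__"):
            resolved[name] = next(value)  # Iterators advance by one step per update
        else:
            resolved[name] = value
    return resolved

wells_and_sites = [("B03", "001"), ("B03", "002")]  # Made-up wells and sites
properties = {
    "well_site_tuple": itertools.cycle(wells_and_sites),
    "well": lambda well_site_tuple: well_site_tuple[0],
    "site": lambda well_site_tuple: well_site_tuple[1],
}

first = resolve_properties(properties)
second = resolve_properties(properties)
print(first["well"], first["site"])    # B03 001
print(second["well"], second["site"])  # B03 002
```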

### 2.3 The brightfield image loader

We use `deeptrack.LoadImage` to load a brightfield image. It takes the following arguments:

* `**root.properties` passes on the properties of `root` (most importantly `well` and `site`). The other properties of `LoadImage` now depend on these.
* `file_name` is a dummy property that takes the current well and site as input and builds the name of the file to load (the values 4 and 1 fill the `A0`/`C0` and `Z0` fields of `file_name_struct`).
* `path` is the property used by `LoadImage` to determine which file to load. We calculate it by joining `PATH_TO_MAGNIFICATION` and `file_name` with `os.path.join`.

Because `path` is a single path here rather than a list, `LoadImage` loads one image per update instead of stacking several images along the last dimension.

In [10]:
brightfield_loader = dt.LoadImage(
    **root.properties,
    file_name=lambda well, site: file_name_struct.format(well, site, 4, 1),
    path=lambda file_name: os.path.join(PATH_TO_MAGNIFICATION, file_name),
)

### 2.4 The fluorescence image loader

We use `deeptrack.LoadImage` to load a fluorescence image. It takes the following arguments:

* `**root.properties` passes on the properties of `root` (most importantly `well` and `site`). The other properties of `LoadImage` now depend on these.
* `file_name` is a dummy property that takes the current well and site as input and builds the name of the file to load (the values 2 and 1 fill the `A0`/`C0` and `Z0` fields of `file_name_struct`).
* `path` is the property used by `LoadImage` to determine which file to load. We calculate it by joining `PATH_TO_MAGNIFICATION` and `file_name` with `os.path.join`.

Because `path` is a single path here rather than a list, `LoadImage` loads one image per update instead of stacking several images along the last dimension.

In [11]:
fluorescence_loader = dt.LoadImage(
    **root.properties,
    file_name=lambda well, site: file_name_struct.format(well, site, 2, 1),
    path=lambda file_name: os.path.join(PATH_TO_MAGNIFICATION, file_name),
)

In [14]:
corrected_brightfield = brightfield_loader + correct_offset

data_pair = dt.Combine([corrected_brightfield, fluorescence_loader]) + dt.AsType("float64")

## 3. Define generator

We use generators to interface DeepTrack 2.0 features with Keras training routines. In DeepTrack 2.0, we have defined some special generators that speed up training. Here, we will use `deeptrack.ContinuousGenerator`, which continuously generates augmented training images and makes them available for training the neural network model.

In [15]:
def fit(X, a, b, c, d):
    """Isotropic 2D Gaussian with amplitude a, center (b, c), and width d."""
    x, y = X
    return a * np.exp(-((x - b) ** 2 + (y - c) ** 2) / (2 * d ** 2))


def calculate_total_error(num_files):
    """Return the median fitted offset between the two channels for each of num_files image pairs."""
    patch_size = 512
    # 21 x 21 window around the center of the shifted cross-correlation
    window = slice(patch_size // 2 - 10, patch_size // 2 + 11)

    _x = np.arange(-patch_size / 2, patch_size / 2)
    X, Y = np.meshgrid(_x, _x)
    Xsmall = X[window, window]
    Ysmall = Y[window, window]

    all_out = []
    for _ in range(num_files):
        im_1, im_2 = data_pair.update().resolve()
        res = []

        for i in range(0, im_1.shape[0] - patch_size, patch_size):
            for j in range(0, im_1.shape[1] - patch_size, patch_size):
                # Cross-correlate the two patches via the Fourier transform
                corr = np.fft.fftshift(
                    np.fft.ifft2(
                        np.fft.fft2(im_1[i : i + patch_size, j : j + patch_size, 0])
                        * np.conjugate(
                            np.fft.fft2(im_2[i : i + patch_size, j : j + patch_size, 1])
                        )
                    )
                )
                corrsmall = np.abs(corr[window, window])

                try:
                    # Fit a Gaussian to the correlation peak; its center is the offset
                    opt, _cov = optimize.curve_fit(
                        fit,
                        np.array([Xsmall.flatten(), Ysmall.flatten()]),
                        corrsmall.flatten(),
                        [np.max(corrsmall), 0, 0, 5],
                    )
                    res.append((opt[1], opt[2]))
                except Exception:
                    # Skip patches where the fit fails
                    pass

        all_out.append(np.median(res, axis=0))

    return all_out
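
The FFT-based cross-correlation used in `calculate_total_error` can be checked on synthetic data. The sketch below is a simplified illustration under stated assumptions: it shifts a random image by a known, made-up amount and locates the correlation peak with `argmax` (integer precision) instead of the Gaussian fit used above.

```python
import numpy as np

rng = np.random.default_rng(0)
reference = rng.normal(size=(64, 64))
# Known, made-up shift: 3 rows down and 2 columns to the left
shifted = np.roll(reference, shift=(3, -2), axis=(0, 1))

corr = np.fft.fftshift(
    np.fft.ifft2(np.fft.fft2(shifted) * np.conjugate(np.fft.fft2(reference)))
)

# After fftshift, zero offset sits at index 64 // 2 along each axis
peak_row, peak_col = np.unravel_index(np.argmax(np.abs(corr)), corr.shape)
offset_rows = peak_row - 64 // 2
offset_cols = peak_col - 64 // 2
print(offset_rows, offset_cols)  # 3 -2
```

Fitting a Gaussian to the neighborhood of the peak, as in the cell above, refines this estimate to sub-pixel precision.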

In [None]:
patience = 10  # Stop once the best score is at least this many iterations old
Bx = Ax = x = By = Ay = y = 0

results = []
scores = []

scale = 1
upper = 1.002
lower = 0.998

step_length = 0.0002


def func(x, a, b, c):
    """First-order harmonic: a * cos(x) + b * sin(x) + c."""
    return a * np.cos(x) + b * np.sin(x) + c


while len(results) == 0 or len(scores) - np.argmin(scores) < patience:
    X, A, B = offset(48, Bx=Bx, Ax=Ax, x=x, By=By, Ay=Ay, y=y, scale=scale)

    # Mean difference between adjacent columns of A and between adjacent rows of B
    mDA = np.mean(A[:4, :3] - A[:4, 1:4])
    mDB = np.mean(B[:3, :4] - B[1:4, :4])
    print(scale, mDA, mDB)

    # Adjust the scale according to the sign of mDA + mDB, with a shrinking step
    if mDA + mDB < 0:
        scale -= step_length
    else:
        scale += step_length
    step_length = step_length * 0.95

    x2, y2 = list(zip(*X))
    ang = np.linspace(0, 2 * np.pi / 6 * len(x2), len(x2))

    # The score is the total absolute residual offset for the current parameters
    score = np.sum(np.abs(X))
    scores.append(score)
    results.append((Bx, Ax, x, By, Ay, y))

    # Add the residual offsets to the current model and refit the harmonic parameters
    x2n = func(ang, Bx, Ax, x) + x2
    y2n = func(ang, By, Ay, y) + y2
    Bx, Ax, x = optimize.curve_fit(func, ang, x2n, [0, 0, 0])[0]
    By, Ay, y = optimize.curve_fit(func, ang, y2n, [0, 0, 0])[0]

    plt.plot(X)
    plt.plot(func(ang, Bx, Ax, x), linestyle=":")
    plt.plot(func(ang, By, Ay, y), linestyle=":")
    plt.show()

    print(score, len(scores) - np.argmin(scores))
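
The harmonic fit at the core of the loop can be tried in isolation. The sketch below is a minimal example on synthetic, noise-free data: the parameter values in `true_params` are made up, and the fit uses the same model and the same initial guess `[0, 0, 0]` as above.

```python
import numpy as np
from scipy import optimize

def func(x, a, b, c):
    """First-order harmonic: a * cos(x) + b * sin(x) + c."""
    return a * np.cos(x) + b * np.sin(x) + c

ang = np.linspace(0, 2 * np.pi, 50)
true_params = (1.5, -0.7, 0.3)  # Made-up values, for illustration only
samples = func(ang, *true_params)

# curve_fit returns the fitted parameters and their covariance matrix
fitted, _cov = optimize.curve_fit(func, ang, samples, [0, 0, 0])
print(np.round(fitted, 3))
```

Because the model is linear in `a`, `b`, and `c`, the fit recovers the parameters from the zero initial guess without needing a good starting point.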