In [None]:
import sys
sys.path.append("..")

# DeepTrack - Tracking multiple particles with a U-net

This notebook demonstrates how to track multiple particles using a U-net with DeepTrack.

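Specifically, this tutorial explains how to:
1. Define point particles with random positions.
2. Define the optical system and the noise sources.
3. Combine them into images containing a random number of particles.
4. Create target images that mark the particle positions.
5. Define, train, and evaluate a U-net that maps images to target images.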

The model receives as input an image that may or may not contain particles, and outputs an image of the same shape with pixel values between 0 and 1. A pixel value close to 1 indicates high confidence that a particle is near that pixel, while a value close to 0 indicates high confidence that no particle is nearby.

This tutorial should be read after the tutorial [tracking_point_particle](tracking_point_particle.ipynb).

## 1. Setup

Imports needed for this tutorial.

In [None]:
from deeptrack.scatterers import PointParticle
from deeptrack.optics import OpticalDevice
from deeptrack.noises import Poisson, Offset
from deeptrack.generators import Generator
from deeptrack.models import unet
from deeptrack.losses import weighted_crossentropy, flatten

import numpy as np
import matplotlib.pyplot as plt

## 2. Define the particle

For this example, we consider point particles (i.e. point light scatterers). A point particle is an instance of the class `PointParticle`, defined by its intensity and its position. Here, the position is randomized using a lambda function, which draws a new random position (in pixels, within the 256×256 image) every time the feature is updated. More details can be found in the tutorial [tracking_point_particle](tracking_point_particle.ipynb).

In [None]:
particle = PointParticle(                                         
    intensity=100,
    position=lambda: np.random.rand(2) * 256,
    position_unit="pixel"
)
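The lambda in `position` is what makes the particle position random: instead of a fixed value, the property holds a function that is called again whenever the feature is updated. The plain-NumPy sketch below only mimics that behavior; the helper `resolve_property` is hypothetical and is not part of DeepTrack's API.

```python
import numpy as np

def resolve_property(value):
    """Hypothetical helper: call the value if it is callable, else return it unchanged."""
    return value() if callable(value) else value

# Same definitions as in the PointParticle above
intensity = 100
position = lambda: np.random.rand(2) * 256

# Each "update" re-evaluates the lambda and yields a new random position,
# while the constant intensity stays the same
samples = [(resolve_property(intensity), resolve_property(position)) for _ in range(3)]
for value, pos in samples:
    print(value, pos)
```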

## 3. Define the optics 

Next, we define the properties of the optical system. This is done using an instance of the class `OpticalDevice`, which takes a set of light scatterers (particles) and convolves them with the point spread function of the optical system. The point spread function is set by the numerical aperture `NA` and the `wavelength`, while `pixel_size` sets the physical size of one image pixel. More details can be found in the tutorial [tracking_point_particle](tracking_point_particle.ipynb).

In [None]:
Optics = OpticalDevice(
    NA=0.7,                # numerical aperture of the objective
    pixel_size=0.1e-6,     # size of one pixel (m)
    wavelength=680e-9      # wavelength of the light (m)
)
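To get a feeling for these parameters, the sketch below assumes an ideal, in-focus circular pupil, whose point spread function is the Airy pattern $I(r) \propto [2J_1(v)/v]^2$ with $v = 2\pi\,\mathrm{NA}\,r/\lambda$. Its first dark ring lies at $r \approx 0.61\lambda/\mathrm{NA}$, which for the values above is about 0.59 µm, i.e. roughly 6 pixels. The PSF actually computed by `OpticalDevice` may differ; `airy_psf` is a hypothetical illustration that uses SciPy rather than DeepTrack.

```python
import numpy as np
from scipy.signal import fftconvolve
from scipy.special import j1

NA = 0.7              # numerical aperture, as above
wavelength = 680e-9   # wavelength in meters, as above
pixel_size = 0.1e-6   # pixel size in meters, as above

# Radius of the first dark ring of the Airy pattern, in pixels
airy_radius_px = 0.61 * wavelength / NA / pixel_size

def airy_psf(size, NA, wavelength, pixel_size):
    """Hypothetical sketch: normalized Airy PSF of an ideal circular pupil on a size x size grid."""
    half = size // 2
    y, x = np.mgrid[-half:half + 1, -half:half + 1] * pixel_size  # pixel coordinates in meters
    v = 2 * np.pi * NA * np.hypot(x, y) / wavelength             # optical coordinate
    psf = np.ones_like(v)                                          # limit of (2 J1(v) / v)^2 at v = 0
    nonzero = v > 0
    psf[nonzero] = (2 * j1(v[nonzero]) / v[nonzero]) ** 2
    return psf / psf.sum()

# A single point scatterer with intensity 100, imaged through the PSF
point_image = np.zeros((64, 64))
point_image[32, 32] = 100.0
blurred = fftconvolve(point_image, airy_psf(31, NA, wavelength, pixel_size), mode="same")
print(f"Airy radius: {airy_radius_px:.2f} pixels")
```

Because the PSF is normalized to sum to 1, the blurred spot keeps the total intensity of the point particle while spreading it over several pixels.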

## 4. Define noises

We introduce two sources of noise:
1. Poisson (shot) noise with a signal-to-noise ratio (SNR) between 10 and 30.
2. A background offset between 0 and 20.

In [None]:
poisson_noise = Poisson(
    snr=np.linspace(10,30)
)

offset = Offset(
    offset=lambda: np.random.rand()*20
)
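How these two noise sources act on an image can be sketched with NumPy. The sketch assumes one common convention, which may not match DeepTrack's `Poisson` implementation exactly: the SNR is taken as the square root of the expected photon count at the brightest pixel, so the image is rescaled to a peak of $\mathrm{SNR}^2$ photons, sampled from a Poisson distribution, and scaled back. As in the cell above, the SNR is drawn from `np.linspace(10, 30)` and the offset from $[0, 20)$. The function `add_noise` is hypothetical.

```python
import numpy as np

def add_noise(image, snr, offset):
    """Hypothetical sketch: add a constant background offset, then Poisson noise at a given SNR."""
    image = image + offset
    # Assumed convention: SNR = sqrt(photons at the peak), so rescale the peak to snr**2 photons
    photons_per_unit = snr**2 / image.max()
    noisy_photons = np.random.poisson(image * photons_per_unit)
    return noisy_photons / photons_per_unit

snr_value = np.random.choice(np.linspace(10, 30))  # random SNR from the same set of values
offset_value = np.random.rand() * 20               # random offset in [0, 20)

clean = np.zeros((64, 64))
clean[28:36, 28:36] = 100.0  # a bright square standing in for a particle
noisy = add_noise(clean, snr=snr_value, offset=offset_value)
```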

## 5. Define the image features

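We want images with a random number of particles between 1 and 10, a background offset, and Poisson noise. The expression `particle**num_particles` adds a random number of copies of `particle`, drawn from the list `num_particles`, each with its own random position. The offset is added to the particles before they are imaged by `Optics`, and Poisson noise is then added to the resulting image.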

In [None]:
num_particles = list(range(1, 11))  # 1 to 10 particles (range excludes its upper bound)

image_features = Optics(particle**num_particles + offset) + poisson_noise
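Conceptually, each time `image_features` is updated, a number of particles is drawn from `num_particles` and each copy of `particle` receives its own random position. The plain-NumPy sketch below reproduces only this sampling step (it does not call DeepTrack). It uses `range(1, 11)` so that every count from 1 to 10 is possible. The function `sample_particle_positions` is hypothetical.

```python
import numpy as np

def sample_particle_positions(counts, image_size=256):
    """Hypothetical sketch: draw a particle count, then one random (x, y) position per particle."""
    n = np.random.choice(counts)               # number of particles in this image
    return np.random.rand(n, 2) * image_size   # positions in pixels, as in the lambda above

particle_counts = list(range(1, 11))  # range excludes its upper bound, so this is 1, 2, ..., 10

for _ in range(3):
    positions = sample_particle_positions(particle_counts)
    print(f"{len(positions)} particles, first at {positions[0]}")
```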

## 6. Plot example images

Now, we visualize some example images. At each iteration, we call the method `.update()` to refresh the random features in the image (number of particles, particle positions, offset level, and Poisson noise). Afterwards, we call the method `.plot((256, 256))` to generate and display the image, where `(256, 256)` defines the size of the image and the input image is implicitly set to `np.zeros((256, 256))`.

In [None]:
for i in range(4):
    image_features.update()
    output_image = image_features.plot((256,256))

## 7. Create the target images

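We now use the generated images to create the target images for training. For each image, the function `get_target_image` reads the positions of the point particles stored in the image properties and returns a binary image of the same shape, in which the pixels closer than 2 pixels to a particle center are set to 1 and all other pixels are set to 0.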

In [None]:
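def get_target_image(image):
    # Binary target image with the same shape as the input image
    label = np.zeros(image.shape)

    # Pixel coordinate grids: X holds column indices, Y holds row indices
    X, Y = np.meshgrid(
        np.arange(0, image.shape[0]),
        np.arange(0, image.shape[1])
    )

    # The properties of the features used to create the image are stored in image.properties
    for prop in image.properties:
        if prop["name"] == "PointParticle":
            position = prop["position"]

            # Squared distance from every pixel to the particle center
            distance_map = (X - position[0])**2 + (Y - position[1])**2
            label[distance_map < 4] = 1  # squared distance < 4, i.e. distance < 2 pixels

    # Add a trailing channel axis: (height, width) -> (height, width, 1)
    return np.expand_dims(label, axis=-1)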

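The same construction can be checked without DeepTrack by passing particle positions directly. In the sketch below, `make_target` is a hypothetical stand-in for `get_target_image` that takes a list of (x, y) positions instead of reading them from `image.properties`. It uses the same conventions (the first coordinate is compared with the column index, and pixels with squared distance below 4 are marked). The hypothetical helper `plot_image_and_target` shows an image and its target next to each other.

```python
import numpy as np
import matplotlib.pyplot as plt

def make_target(shape, positions, radius=2):
    """Hypothetical stand-in for get_target_image: mark pixels closer than `radius` to any position."""
    label = np.zeros(shape)
    X, Y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))  # X: column index, Y: row index
    for x, y in positions:
        distance_map = (X - x) ** 2 + (Y - y) ** 2
        label[distance_map < radius**2] = 1
    return np.expand_dims(label, axis=-1)  # shape (height, width, 1)

def plot_image_and_target(image, target):
    """Show an input image and its target image side by side."""
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    axes[0].imshow(image)
    axes[0].set_title("Image")
    axes[1].imshow(target[..., 0])
    axes[1].set_title("Target")
    return fig

target = make_target((256, 256), [(100, 100), (50.5, 200.3)])
```

Passing `np.arange(shape[1])` first to `np.meshgrid` makes the grids match the image shape even for non-square images; for the square 256×256 images used here, this is equivalent to the ordering in `get_target_image`.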

## 8. Define image generator

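The `Generator` creates the training data on the fly: every time a batch is requested, new images are generated from `image_features` (with freshly drawn random properties), and the corresponding targets are computed with `get_target_image`. We define two generators: `generator`, which yields batches of 16 (image, target) pairs for training, and `validation_generator`, which yields one pair at a time and can be used to inspect the model on individual images.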

In [None]:
generator = Generator().generate(
    image_features, 
    get_target_image, 
    shape=(256, 256), 
    batch_size=16
)

validation_generator = Generator().generate(
    image_features, 
    get_target_image, 
    shape=(256, 256), 
    batch_size=1
)
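The idea behind these generators is that the training data never has to be stored, because each batch is created when it is requested. A minimal sketch of that pattern with a plain Python generator follows. `batch_generator`, `toy_image`, and `toy_target` are hypothetical placeholders rather than DeepTrack functions, and the batches have the shape `(batch_size, 256, 256, 1)` expected by a U-net with input shape `(256, 256, 1)`.

```python
import numpy as np

def batch_generator(make_image, make_target, batch_size=16):
    """Hypothetical sketch: yield (images, targets) batches created on the fly, forever."""
    while True:
        images, targets = [], []
        for _ in range(batch_size):
            image = make_image()
            images.append(image[..., np.newaxis])  # add channel axis -> (H, W, 1)
            targets.append(make_target(image))
        yield np.stack(images), np.stack(targets)

# Toy stand-ins: random images whose target marks the brightest pixel
def toy_image():
    return np.random.rand(256, 256)

def toy_target(image):
    target = np.zeros(image.shape + (1,))
    row, col = np.unravel_index(np.argmax(image), image.shape)
    target[row, col, 0] = 1
    return target

train_batches = batch_generator(toy_image, toy_target, batch_size=16)
x, y = next(train_batches)
```

Since such a generator never terminates, Keras `model.fit` needs `steps_per_epoch` to know when an epoch ends.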

## 9. Define the neural network model

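We use the U-net provided by `deeptrack.models.unet`. The input shape `(256, 256, 1)` corresponds to single-channel 256×256 images, and `conv_layers_dimensions` sets the number of convolutional filters at each level of the network. Since only a small fraction of the pixels in each target image belong to particles, an unweighted loss would favor predicting background everywhere. We therefore use a weighted cross-entropy, `weighted_crossentropy((90, 1))`, which gives errors on particle pixels a weight of 90 relative to errors on background pixels; `flatten` flattens the target and predicted images before the loss is computed.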

In [None]:
model = unet(
    (256, 256, 1), 
    conv_layers_dimensions=[16, 32, 32, 32], 
    loss=flatten(weighted_crossentropy((90, 1)))
)
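The effect of the weights can be illustrated with a NumPy version of a pixel-wise weighted binary cross-entropy. This is a sketch under the assumption that the first weight multiplies the particle term (target = 1) and the second the background term (target = 0); DeepTrack's own `weighted_crossentropy` may include additional normalization. The function `weighted_crossentropy_np` is hypothetical.

```python
import numpy as np

def weighted_crossentropy_np(target, prediction, weight=(90, 1), eps=1e-4):
    """Hypothetical NumPy sketch of a pixel-wise weighted binary cross-entropy.

    weight[0] scales the error on particle pixels (target == 1),
    weight[1] scales the error on background pixels (target == 0).
    """
    t = np.ravel(target)      # flatten target and prediction before computing the loss
    p = np.ravel(prediction)
    loss = -(weight[0] * t * np.log(p + eps) + weight[1] * (1 - t) * np.log(1 - p + eps))
    return loss.mean()

# A missed particle pixel (target 1, prediction 0.1) ...
missed = weighted_crossentropy_np(np.array([1.0]), np.array([0.1]))
# ... costs about 90 times more than a false alarm of the same size (target 0, prediction 0.9)
false_alarm = weighted_crossentropy_np(np.array([0.0]), np.array([0.9]))
print(missed / false_alarm)
```

With `weight=(1, 1)`, the function reduces to the ordinary binary cross-entropy (up to the small `eps`).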


## 10. Train the model

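The model is trained by calling `.fit()` with the training generator. Each epoch consists of `steps_per_epoch=10` batches of 16 images, so the model sees 10 × 16 = 160 images per epoch and 20 × 160 = 3200 images over the 20 epochs. Because the images are generated on the fly, every batch contains new images.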

In [None]:
model.fit(
    generator, 
    epochs=20, 
    steps_per_epoch=10
)

## 11. Visualize the model performance

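We now test the trained model on newly generated images. For each of five batches drawn from the generator, we plot the first input image, the corresponding model prediction, and the target image (ground truth) side by side.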

In [None]:
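for i in range(5):
    # Draw a new batch of input images and their target images
    input_image, target_image = next(generator)

    # Predict the target image for every image in the batch
    predicted_image = model.predict(input_image)

    plt.figure(figsize=(15, 5))

    # Show the first image of the batch, its prediction, and its target
    plt.subplot(1, 3, 1)
    plt.imshow(input_image[0, :, :, 0])
    plt.title("Input Image")

    plt.subplot(1, 3, 2)
    plt.imshow(predicted_image[0, :, :, 0])
    plt.title("Predicted Image")

    plt.subplot(1, 3, 3)
    plt.imshow(target_image[0, :, :, 0] > 0.5)
    plt.title("Target Image (Ground Truth)")

    plt.show()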
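The predicted image is a confidence map rather than a list of coordinates. One common way to turn it into particle positions, which is not part of this tutorial, is to threshold the map and take the center of mass of each connected region. The sketch below does this with `scipy.ndimage.label` and `scipy.ndimage.center_of_mass`. The function name `extract_positions` and the threshold of 0.5 are assumptions. In the loop above, it could be applied as `extract_positions(predicted_image[0, :, :, 0])`.

```python
import numpy as np
from scipy import ndimage

def extract_positions(prediction, threshold=0.5):
    """Hypothetical post-processing: (row, column) centroids of the regions above threshold."""
    mask = prediction > threshold
    labels, num_regions = ndimage.label(mask)  # connected regions of confident pixels
    if num_regions == 0:
        return np.empty((0, 2))
    # Confidence-weighted center of mass of each labeled region
    centroids = ndimage.center_of_mass(prediction, labels, np.arange(1, num_regions + 1))
    return np.array(centroids)

# Toy prediction with two blobs of high confidence
prediction = np.zeros((64, 64))
prediction[9:12, 19:22] = 0.9
prediction[40:43, 50:53] = 0.8
print(extract_positions(prediction))
```

Note that `center_of_mass` returns (row, column) coordinates, whereas `get_target_image` compares the first position coordinate with the column index, so the two columns must be swapped before comparing with the particle positions.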