In [ ]:
%matplotlib inline
import sys
sys.path.append("..")  # make a deeptrack package located one directory up importable

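# DeepTrack - Tracking multiple particles with a U-Net

This notebook demonstrates how to track multiple particles with a U-Net using DeepTrack.

Specifically, this tutorial explains how to:

1. Define the particles, the optics, and the noise sources.
2. Generate training images with a random number of particles, together with their target images.
3. Define and train a U-Net on these image/target pairs.
4. Visualize the model's predictions.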

The model receives as input an image that may or may not contain particles, and it outputs an image of the same shape with pixel values between 0 and 1. A pixel value close to 1 indicates high confidence that a particle lies near that pixel, while a value close to 0 indicates high confidence that there is no particle nearby.

This tutorial should be read after the tutorial [tracking_point_particle](tracking_point_particle.ipynb).

## 1. Setup

Imports needed for this tutorial.

In [ ]:
from deeptrack.scatterers import PointParticle
from deeptrack.optics import OpticalDevice
from deeptrack.noises import Poisson, Offset
from deeptrack.generators import Generator
from deeptrack.models import unet
from deeptrack.losses import weighted_crossentropy, sigmoid, flatten

import numpy as np
import matplotlib.pyplot as plt

## 2. Define the particle

For this example, we consider point particles (i.e., point light scatterers). A point particle is an instance of the class `PointParticle`, defined by its intensity and its position. Here, the position is randomized with a lambda function, which is re-evaluated every time the feature is updated, so each update places the particle uniformly at random in [0, 256) pixels along each axis. More details can be found in the tutorial [tracking_point_particle](tracking_point_particle.ipynb).

In [ ]:
particle = PointParticle(                                         
    intensity=100,
    position=lambda: np.random.rand(2) * 256,
    position_unit="pixel"
)

## 3. Define the optics 

Next, we define the properties of the optical system. This is done using an instance of the class `OpticalDevice`, which takes a set of light scatterers (particles) and convolves them with the point spread function of the optical system, which is determined by its pupil function. More details can be found in the tutorial [tracking_point_particle](tracking_point_particle.ipynb).

In [ ]:
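Optics = OpticalDevice(
    NA=0.7,                # numerical aperture of the objective
    pixel_size=0.1e-6,     # pixel size in meters (0.1 µm)
    wavelength=680e-9      # wavelength in meters (680 nm)
)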
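As a rough check on these numbers, the textbook Rayleigh criterion, 0.61 λ / NA, gives the radius of the diffraction-limited (Airy) spot. The snippet below is a back-of-the-envelope calculation with that classical formula only; it does not reproduce how `OpticalDevice` computes its point spread function.

```python
# Back-of-the-envelope spot size from the Rayleigh criterion (textbook formula,
# not DeepTrack's internal PSF computation).
NA = 0.7              # numerical aperture, as passed to OpticalDevice above
wavelength = 680e-9   # wavelength in meters
pixel_size = 0.1e-6   # pixel size in meters (0.1 µm)

rayleigh_radius = 0.61 * wavelength / NA            # first zero of the Airy pattern, meters
rayleigh_radius_px = rayleigh_radius / pixel_size   # the same distance in pixels

print(f"Airy radius ≈ {rayleigh_radius * 1e9:.0f} nm ≈ {rayleigh_radius_px:.1f} px")
# prints "Airy radius ≈ 593 nm ≈ 5.9 px"
```

By this classical criterion, two particles closer than about 6 pixels are not resolved as separate spots with these optics.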

## 4. Define noises

We introduce two sources of noise:
1. Poisson noise with an SNR between 30 and 50.
2. A background offset between 0 and 20.

In [ ]:
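poisson_noise = Poisson(
    snr=np.linspace(30, 50)    # SNR values between 30 and 50 (np.linspace returns 50 evenly spaced values)
)

offset = Offset(
    offset=lambda: np.random.rand() * 20    # new background level in [0, 20) at each update
)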
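To make the two noise sources concrete, here is a minimal NumPy sketch that adds a constant background and then Poisson (shot) noise at a given SNR. It assumes the SNR is defined as the square root of the expected photon count at the brightest pixel, so a pixel holding N photons fluctuates by about 1/√N relative to its mean. The function `add_offset_and_poisson_noise` is hypothetical, and DeepTrack's `Offset` and `Poisson` features may define or order these steps differently.

```python
import numpy as np

def add_offset_and_poisson_noise(image, offset, snr, rng=None):
    """Hypothetical sketch: add a constant background, then Poisson (shot) noise.

    Assumes SNR = sqrt(N_peak), where N_peak is the expected photon count at
    the brightest pixel. DeepTrack's own features may use another definition.
    """
    rng = np.random.default_rng() if rng is None else rng
    image = image + offset                     # uniform background level
    peak = image.max()
    if peak <= 0:
        return image                           # no signal to scale, so no shot noise
    photons_per_unit = snr**2 / peak           # brightest pixel holds snr**2 photons
    counts = rng.poisson(image * photons_per_unit)
    return counts / photons_per_unit           # back to the original intensity scale


rng = np.random.default_rng(0)
clean = np.zeros((64, 64))
# offset=10.0 stands in for one draw of np.random.rand() * 20
noisy = add_offset_and_poisson_noise(clean, offset=10.0, snr=30, rng=rng)
```

On a flat background the relative fluctuation of the noisy image is about 1/SNR, which is one simple way to read the SNR values chosen above.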

## 5. Define the image features

We want images with a random number of particles between 1 and 10, a background offset, and Poisson noise.

In [ ]:
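# A lambda is re-evaluated at each .update(), so every image gets a new number
# of particles between 1 and 10 (a plain integer would be drawn only once).
num_particles = lambda: np.random.randint(1, 11)

image_features = Optics(particle**num_particles + offset) + poisson_noise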
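Conceptually, each update draws a new particle count and new positions, and resolving the features images those particles through the optics. The pure-NumPy sketch below mimics only the particle-imaging part, under two assumptions: the point spread function is approximated by a Gaussian (the PSF from `OpticalDevice` is not Gaussian), and the helper `render_point_particles` is hypothetical, not part of DeepTrack.

```python
import numpy as np

def render_point_particles(shape, positions, intensity=100.0, sigma=2.0):
    """Hypothetical sketch: image point particles through a Gaussian-approximated PSF.

    positions: array of shape (N, 2) with (row, col) coordinates in pixels.
    sigma: PSF width in pixels; sigma ≈ 0.21 * λ / NA ≈ 2 px is a commonly used
    Gaussian approximation of the Airy spot for NA = 0.7, λ = 680 nm, 0.1 µm pixels.
    """
    rows, cols = np.indices(shape)             # row and column index of every pixel
    image = np.zeros(shape)
    for r, c in positions:
        # Each particle adds one Gaussian spot centered at its subpixel position
        image += intensity * np.exp(-((rows - r) ** 2 + (cols - c) ** 2) / (2 * sigma**2))
    return image


rng = np.random.default_rng(1)
num_particles = rng.integers(1, 11)                # 1..10 particles, like randint(1, 11)
positions = rng.random((num_particles, 2)) * 256   # uniform in [0, 256) along each axis
image = render_point_particles((256, 256), positions)
```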

## 6. Plot example images

Now, we visualize some example images. At each iteration, we call the method `.update()` to refresh the random features in the image (particle number, particle position, offset level, and Poisson noise). Afterwards we call the method `.plot((256, 256))` to generate the image, where `(256, 256)` defines the size of the image and the input image is implicitly set to `np.zeros((256, 256))`.

In [ ]:
for i in range(4):
    image_features.update()
    output_image = image_features.plot((256,256), cmap="gray")

## 7. Create the target images

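We define a function that creates the target image for each generated image: a binary mask that is 1 within 2 pixels of each particle center and 0 elsewhere. The particle positions are read from the properties stored with the generated image. We also show images and targets side by side.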

In [ ]:
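# Returns a binary target mask, with a trailing channel axis, that is 1 inside
# a disk of radius 2 pixels around each particle in the input image and 0 elsewhere.
def get_target_image(image_of_particles):
    label = np.zeros(image_of_particles.shape)
    X, Y = np.meshgrid(
        np.arange(0, image_of_particles.shape[0]), 
        np.arange(0, image_of_particles.shape[1])
    )

    # The resolved image carries the properties of the features that generated it;
    # keep only those belonging to point particles.
    for prop in image_of_particles.properties:
        if prop["name"] == "PointParticle":
            position = prop["position"]

            # Squared distance from every pixel to the particle center;
            # "< 4" selects the pixels within 2 pixels of it.
            distance_map = (X - position[0])**2 + (Y - position[1])**2
            label[distance_map < 4] = 1

    return np.expand_dims(label, axis=-1)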


input_image = np.zeros((256,256))
for i in range(4):
    image_features.update()
    image_of_particles = image_features.resolve(input_image)

    target_image = get_target_image(image_of_particles)

    plt.subplot(1,2,1)
    plt.imshow(image_of_particles, cmap="gray")
    plt.title("Input Image")
    plt.subplot(1,2,2)
    plt.imshow(target_image[:,:,0], cmap="gray")
    plt.title("Ground Truth")
    plt.show()
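The same target can be built without DeepTrack objects, directly from a list of particle positions. This sketch uses a hypothetical helper, `target_from_positions`, and vectorizes the loop over particles with broadcasting. For square images such as the 256 × 256 ones used here, it follows the same convention as `get_target_image`: `position[0]` is compared with the column index `X` and `position[1]` with the row index `Y`.

```python
import numpy as np

def target_from_positions(shape, positions, radius=2.0):
    """Hypothetical helper: binary mask with a disk of `radius` pixels per particle.

    position[0] is compared with X (column index) and position[1] with Y (row
    index), as in get_target_image; a trailing channel axis is added.
    """
    positions = np.asarray(positions, dtype=float).reshape(-1, 2)
    X, Y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))   # X: columns, Y: rows
    # Squared distance from every pixel to every particle, shape (N, rows, cols)
    d2 = (X[None] - positions[:, 0, None, None]) ** 2 + (Y[None] - positions[:, 1, None, None]) ** 2
    label = (d2 < radius**2).any(axis=0).astype(float)   # 1 if inside any disk
    return label[..., None]


mask = target_from_positions((256, 256), [(128, 128)])
```

A particle centered exactly on a pixel yields 9 positive pixels, because squared distances of 0, 1, and 2 are below 4 while a squared distance of exactly 4 is not.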

## 8. Define image generator

We define a generator that produces batches of 8 images of shape (256, 256), each paired with its target image from `get_target_image`.

In [ ]:
generator = Generator().generate(
    image_features, 
    get_target_image, 
    shape=(256, 256), 
    batch_size=8
)
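A generator of this kind is essentially an endless Python iterator that keeps producing `(images, targets)` batches for Keras. The sketch below illustrates that pattern with plain NumPy. `batch_generator` and the toy image and target functions are hypothetical stand-ins, and adding a trailing channel axis to the images is an assumption based on the U-Net's `(256, 256, 1)` input shape. It does not describe DeepTrack's `Generator` internals.

```python
import numpy as np

def batch_generator(make_image, make_target, batch_size=8):
    """Hypothetical sketch: endlessly yield (images, targets) batches.

    make_image() returns a 2D array; make_target(image) returns an array with a
    trailing channel axis, like get_target_image above.
    """
    while True:
        images, targets = [], []
        for _ in range(batch_size):
            image = make_image()                 # a fresh random image on every call
            images.append(image[..., None])      # add channel axis -> (H, W, 1)
            targets.append(make_target(image))
        yield np.stack(images), np.stack(targets)


# Toy stand-ins: random images and all-zero targets of matching shape.
rng = np.random.default_rng(2)
gen = batch_generator(lambda: rng.random((256, 256)), lambda im: np.zeros(im.shape + (1,)))
x_batch, y_batch = next(gen)
```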

## 9. Define the neural network model

The network architecture is a U-Net, a fully convolutional model designed for image-to-image transformations. Since the desired output is a binary image, we use binary cross-entropy as the loss. However, the target image is dominated by zeros: each particle contributes only about a dozen positive pixels out of 65,536, so an unweighted loss would already be low for a network that predicts 0 everywhere. We therefore weight the loss such that false negatives are penalized ten times more than false positives.
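Written out, a class-weighted binary cross-entropy multiplies the loss term for positive pixels (target 1) by one weight and the term for negative pixels (target 0) by another. The NumPy sketch below assumes that the pair `(10, 1)` passed to `weighted_crossentropy` means weight 10 for positives and 1 for negatives, as the paragraph above describes. DeepTrack's exact definition and reduction may differ, so treat this as an illustration of the idea.

```python
import numpy as np

def weighted_bce(y_true, y_pred, weights=(10.0, 1.0), eps=1e-7):
    """Sketch of class-weighted binary cross-entropy, averaged over pixels.

    weights[0] scales the term for target 1 (a missed particle, i.e. a false negative);
    weights[1] scales the term for target 0 (a spurious detection, i.e. a false positive).
    """
    w_pos, w_neg = weights
    p = np.clip(y_pred, eps, 1 - eps)                  # avoid log(0)
    loss = -(w_pos * y_true * np.log(p) + w_neg * (1 - y_true) * np.log(1 - p))
    return loss.mean()


# The same size of error costs 10x more on a particle pixel than on background.
fn = weighted_bce(np.array([1.0]), np.array([0.1]))   # particle pixel predicted as 0.1
fp = weighted_bce(np.array([0.0]), np.array([0.9]))   # background pixel predicted as 0.9
```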

In [ ]:
model = unet(
    (256, 256, 1), 
    conv_layers_dimensions=[8, 16, 32],
    base_conv_layers_dimensions=[32, 32], 
    loss=flatten(weighted_crossentropy((10, 1)))
)

## 10. Train the model

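The model is trained by calling `.fit()` with the generator. Each epoch consists of 20 batches of 8 generated images, so over 50 epochs the network sees 50 × 20 × 8 = 8000 images.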

In [ ]:
model.fit(
    generator, 
    epochs=50, 
    steps_per_epoch=20
)

## 11. Visualize the model performance

Finally, we evaluate the model performance by showing the model output beside the input image and the ground truth.

In [ ]:
input_image, target_image = next(generator)

# Run the model once on the whole batch instead of once per plotted image.
predicted_image = model.predict(input_image)

for i in range(input_image.shape[0]):
    
    plt.subplot(1,3,1)
    plt.imshow(np.squeeze(input_image[i, :, :, 0]), cmap="gray")
    plt.title("Input Image")

    plt.subplot(1,3,2)
    plt.imshow(np.squeeze(predicted_image[i, :, :, 0]), cmap="gray")
    plt.title("Predicted Image")
    
    plt.subplot(1,3,3)
    plt.imshow(np.squeeze(target_image[i, :, :, 0] > 0.5), cmap="gray")
    plt.title("Ground Truth")

    plt.show()
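To go from the predicted probability map to particle coordinates, one common post-processing step (not part of this tutorial's code) is to threshold the output, group the above-threshold pixels into connected blobs, and take each blob's centroid. The helper `extract_positions` below is a hypothetical sketch of that step, using only NumPy and a breadth-first flood fill over 4-connected pixels; the 0.5 threshold is an assumption, not a tuned value.

```python
from collections import deque
import numpy as np

def extract_positions(prob_map, threshold=0.5):
    """Hypothetical sketch: centroids of 4-connected blobs above `threshold`.

    Returns a list of (row, col) centroids in pixel units.
    """
    mask = prob_map > threshold
    visited = np.zeros_like(mask, dtype=bool)
    rows, cols = mask.shape
    centroids = []
    for r0, c0 in zip(*np.nonzero(mask)):
        if visited[r0, c0]:
            continue
        # Breadth-first flood fill collects every pixel of one connected blob
        queue, pixels = deque([(r0, c0)]), []
        visited[r0, c0] = True
        while queue:
            r, c = queue.popleft()
            pixels.append((r, c))
            for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                rr, cc = r + dr, c + dc
                if 0 <= rr < rows and 0 <= cc < cols and mask[rr, cc] and not visited[rr, cc]:
                    visited[rr, cc] = True
                    queue.append((rr, cc))
        centroids.append(tuple(np.mean(pixels, axis=0)))   # blob center of mass (unweighted)
    return centroids


# Usage on the network output for the first image of the batch:
# positions = extract_positions(predicted_image[0, :, :, 0])
```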