In [None]:
%matplotlib inline
import sys
sys.path.append("..")

# DeepTrack 2.0 - Tracking a point particle with a CNN

This tutorial demonstrates how to track a point particle with a convolutional neural network (CNN) using DeepTrack 2.0.

Specifically, this tutorial explains how to:
* Define the procedure to generate training images
* Extract information from these images to use as labels for the training
* Define and train a neural network model
* Visually evaluate the quality of the neural network output

We recommend working through this tutorial after the [deeptrack_introduction_tutorial](deeptrack_introduction_tutorial.ipynb).

## 1. Setup

Imports needed for this tutorial.

In [None]:
from deeptrack.scatterers import PointParticle
from deeptrack.optics import Fluorescence
from deeptrack.generators import Generator
from deeptrack.models import convolutional

import numpy as np
import matplotlib.pyplot as plt

## 2. Define the particle

For this example, we consider a point particle (a point light scatterer). A point particle is an instance of the class `PointParticle` (see also [scatterers_example](../examples/scatterers_example.ipynb)), whose properties are controlled by the following parameters:

* `intensity`: The intensity of the point particle

* `position`: The position of the point particle

* `position_unit`: "pixel" or "meter"

In [None]:
point_particle = PointParticle(                                         
    intensity=100,
    position=(32, 16),
    position_unit="pixel"
)

## 3. Define the optical system 

Next, we need to define the properties of the optical system. This is done using an instance of the class `Fluorescence`, one of the classes in the `optics` module (see also [optics_example](../examples/optics_example.ipynb)), which takes a set of particles (light scatterers) and convolves them with the point spread function of the optical system (obtained from its pupil function). In this tutorial, there is only one light scatterer (here, `point_particle`).

The optical system is controlled by the following parameters:

* `NA`: The numerical aperture

* `resolution`: The effective camera pixel size (m)

* `magnification`: The magnification of the optical device

* `wavelength`: The wavelength of the light source (m)

* `output_region`: The position of the camera and the number of pixels (x, y, width_x, width_y)

* `upscale`: The upscale factor for the pupil function (increases accuracy and computational cost)

In [None]:
IMAGE_SIZE = 64

fluorescence_microscope = Fluorescence(
    NA=0.7,                
    resolution=1e-6,     
    magnification=10,
    wavelength=680e-9,
    output_region=(0, 0, IMAGE_SIZE, IMAGE_SIZE),
    upscale=2
)
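
As a rough sanity check on these values, the sketch below estimates how many pixels the diffraction-limited spot should span. It rests on two assumptions that do not come from DeepTrack itself: that the pixel size in the sample plane is `resolution / magnification`, and that the spot width is approximated by the Abbe limit, wavelength / (2 NA). The variable names are chosen for illustration only.

```python
# Rough estimate of the spot size in pixels.
# Assumption: pixel size in the sample plane = camera pixel size / magnification.
NA = 0.7
wavelength = 680e-9      # m
resolution = 1e-6        # camera pixel size, m
magnification = 10

pixel_size_in_sample = resolution / magnification   # m per pixel
abbe_limit = wavelength / (2 * NA)                   # m

spot_width_in_pixels = abbe_limit / pixel_size_in_sample

print(f"Pixel size in sample plane: {pixel_size_in_sample * 1e9:.0f} nm")
print(f"Abbe limit: {abbe_limit * 1e9:.0f} nm")
print(f"Approximate spot width: {spot_width_in_pixels:.1f} pixels")
```

Under these assumptions the spot covers a few pixels, which is small compared with the 64 x 64 pixel image.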

## 4. Create and plot the image

To view some object through an optical device, we call the optical device (here, `fluorescence_microscope`) with the object we want to image (here, `point_particle`). This creates a new object (here, `imaged_particle`) that can be used to generate the desired image.

The image is finally generated by calling `imaged_particle.resolve()`.

In [None]:
imaged_particle = fluorescence_microscope(point_particle)

output_image = imaged_particle.resolve()

plt.imshow(np.squeeze(output_image), cmap='gray')
plt.show()

## 5. Randomize the particle position

We can generate particles with random positions by passing to the keyword argument `position` a lambda function that returns a pair of random numbers representing the particle position.

In [None]:
# Generate particle with random position

particle_with_random_position = PointParticle(                                         
    intensity=100,
    position=lambda: np.random.rand(2) * IMAGE_SIZE,
    position_unit="pixel"
)

imaged_particle_with_random_position = fluorescence_microscope(particle_with_random_position)

output_image = imaged_particle_with_random_position.resolve()

plt.imshow(np.squeeze(output_image), cmap='gray')
plt.show()
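
To see what the lambda function itself produces, the following sketch calls it a few times outside of DeepTrack. It only illustrates the callable; when and how often DeepTrack evaluates it is not shown here, and the name `sample_position` is introduced purely for this example.

```python
import numpy as np

IMAGE_SIZE = 64

# Same callable as the one passed to `position` above.
sample_position = lambda: np.random.rand(2) * IMAGE_SIZE

# Each call draws a new pair of coordinates, uniform in [0, IMAGE_SIZE).
positions = np.array([sample_position() for _ in range(5)])
print(positions)
```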

The position can then be retrieved from the attribute `.properties` of the generated image, which contains a list of all properties used to create the image.

In [None]:
# Retrieve particle position

def get_position_of_particle(image):
    # Return the first "position" entry found in the image properties
    for image_property in image.properties:
        if "position" in image_property:
            return image_property["position"]

position_of_particle = get_position_of_particle(output_image)

plt.imshow(np.squeeze(output_image), cmap='gray')
plt.scatter(position_of_particle[1], position_of_particle[0])
plt.show()
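
The lookup can be tried without DeepTrack by using a stand-in object. In the sketch below, `MockImage` is a hypothetical class that models only the `.properties` attribute, and the dictionary contents are assumptions made for illustration; the properties that DeepTrack actually stores may differ.

```python
class MockImage:
    """Hypothetical stand-in for a resolved image; only `.properties` is modeled."""

    def __init__(self, properties):
        self.properties = properties


def get_position_of_particle(image):
    # Return the first "position" entry found in the image properties
    for image_property in image.properties:
        if "position" in image_property:
            return image_property["position"]


# Assumed layout: one dictionary per element that contributed to the image.
mock_image = MockImage([
    {"NA": 0.7, "wavelength": 680e-9},           # e.g. properties of the optics
    {"intensity": 100, "position": (32, 16)},    # e.g. properties of the particle
])

print(get_position_of_particle(mock_image))  # prints (32, 16)
```

Note that the function returns `None` if no entry contains the key `"position"`.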

## 6. Define the neural network model

We will use a predefined neural network model obtained by calling the function `convolutional` (see also [models_example](../examples/models_example.ipynb)). This model is a convolutional neural network with a dense top. It receives as input an image of shape `(64, 64, 1)` and outputs two scalar values corresponding to the x and y position of the particle.

The model can be customized using the following arguments:

* `input_shape`: Size of the images to be analyzed.

* `conv_layers_dimensions`: Number of convolutions in each convolutional layer.
    
* `dense_layers_dimensions`: Number of units in each dense layer.
        
* `number_of_outputs`: Number of units in the output layer.

* `output_activation`: The activation function applied to the output layer.

* `loss`: The loss function of the network.

* `optimizer`: The optimizer used for training.

* `metrics`: Additional metrics to evaluate during training.

In [None]:
model = convolutional(
    input_shape=(64, 64, 1), 
    number_of_outputs=2
)

model.summary()

## 7. Define image generator

Generators are objects that feed models with images and their corresponding labels during training. They are created by calling `.generate()` on an instance of the class `Generator` (see also [generators_example](../examples/generators_example.ipynb)). This method takes the following inputs:
* `feature`: A feature (see also [features_example](../examples/features_example.ipynb)) that resolves images used to train a model (here, `imaged_particle_with_random_position`)
* `label_function`: A function that takes an image as input and returns the label for that image (here, `get_scaled_position_of_particle`, defined in the next cell)
* `batch_size`: The number of images per batch

In [None]:
# Function that retrieves the position of a particle
# and divides it by IMAGE_SIZE (64) to get values between 0 and 1
def get_scaled_position_of_particle(image):
    position_of_particle = get_position_of_particle(image)
    return position_of_particle / IMAGE_SIZE

generator = Generator().generate(
    imaged_particle_with_random_position, 
    get_scaled_position_of_particle, 
    batch_size=4
)
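
Conceptually, such a generator is an endless loop that produces batches of images and labels. The sketch below imitates this behavior with plain NumPy and is not DeepTrack's implementation: `make_image_and_position` and `batch_generator` are hypothetical helpers, and the "image" is just a single bright pixel rather than a simulated microscope image.

```python
import numpy as np

IMAGE_SIZE = 64


def make_image_and_position():
    """Hypothetical stand-in for resolving a feature: one bright pixel."""
    position = np.random.rand(2) * IMAGE_SIZE
    image = np.zeros((IMAGE_SIZE, IMAGE_SIZE, 1))
    row, column = position.astype(int)
    image[row, column, 0] = 1.0
    return image, position


def batch_generator(batch_size):
    """Yield (images, labels) batches forever, with labels scaled to [0, 1)."""
    while True:
        images, labels = [], []
        for _ in range(batch_size):
            image, position = make_image_and_position()
            images.append(image)
            labels.append(position / IMAGE_SIZE)
        yield np.array(images), np.array(labels)


sketch_generator = batch_generator(batch_size=4)
batch_images, batch_labels = next(sketch_generator)
print(batch_images.shape, batch_labels.shape)
```

Scaling the labels to the unit interval keeps the regression targets independent of the image size, which is why the positions are rescaled again when plotting.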

## 8. Train the model

The model is trained by calling the method `.fit()` with the generator we defined in the previous step. Be patient: this might take several minutes.

In [None]:
model.fit(
    generator,
    epochs=100,
    steps_per_epoch=64
)
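
To get a feeling for the amount of training data involved, the short calculation below counts the images requested from the generator. It assumes that each training step consumes exactly one batch of the size set in the previous section.

```python
batch_size = 4
steps_per_epoch = 64
epochs = 100

# One batch per step, so each epoch sees batch_size * steps_per_epoch images.
images_per_epoch = batch_size * steps_per_epoch
total_images = images_per_epoch * epochs

print(images_per_epoch, total_images)  # prints 256 25600
```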

## 9. Visualize the model performance

We can now use the trained model to measure the particle position in images previously unseen by the model.

In [None]:
images, real_positions = next(generator)

measured_positions = model.predict(images)

for i in range(images.shape[0]):
    
    image = np.squeeze(images[i])
    
    measured_position_x = measured_positions[i, 1] * IMAGE_SIZE
    measured_position_y = measured_positions[i, 0] * IMAGE_SIZE

    real_position_x = real_positions[i, 1] * IMAGE_SIZE
    real_position_y = real_positions[i, 0] * IMAGE_SIZE

    plt.imshow(image, cmap='gray')
    plt.scatter(real_position_x, real_position_y, s=70, c='r', marker='x')
    plt.scatter(measured_position_x, measured_position_y, s=100, marker='o', facecolor='none', edgecolors='b')
    plt.show()
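
The visual check can be complemented with a simple number, for example the mean distance in pixels between real and measured positions. The sketch below shows one way to compute it. The two arrays are made-up placeholders with the same layout as `real_positions` and `measured_positions` (scaled to the unit interval), not results from the trained model.

```python
import numpy as np

IMAGE_SIZE = 64

# Placeholder values for illustration only; replace them with the arrays
# obtained from the generator and from `model.predict`.
real_positions = np.array([[0.50, 0.25], [0.10, 0.90]])
measured_positions = np.array([[0.52, 0.24], [0.11, 0.88]])

# Euclidean distance between real and measured positions, in pixels.
errors_in_pixels = np.linalg.norm(
    (measured_positions - real_positions) * IMAGE_SIZE, axis=1
)

print("Mean error (pixels):", errors_in_pixels.mean())
```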