In [None]:
%matplotlib inline
import sys
sys.path.append("..")

# DeepTrack - Tracking a point particle with a CNN

This notebook demonstrates how to track point particles with a convolutional neural network using DeepTrack.

Specifically, this tutorial explains how to:
* Define the procedure to generate training images
* Extract information from images to use as labels
* Define and train a neural network model
* Visually evaluate the quality of the neural network output

## 1. Setup

Imports needed for this tutorial.

In [None]:
from deeptrack.scatterers import PointParticle
from deeptrack.optics import OpticalDevice
from deeptrack.generators import Generator
from deeptrack.models import convolutional

import numpy as np
import matplotlib.pyplot as plt

## 2. Define the particle

For this example, we consider a point particle (i.e., a point light scatterer). A point particle is an instance of the class `PointParticle`, defined by its intensity and its position.

A point particle is controlled by the following parameters:

* intensity: The intensity of the point particle

* position: The position of the point particle

* position_unit: The unit of `position`, either "pixel" or "meter"
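To make the `position_unit` option concrete, here is a minimal sketch of how a position given in meters could map to pixels. It assumes the conversion simply divides by the optics' `pixel_size` (in m/px, introduced in the next section). The helper `position_in_pixels` is hypothetical and not part of DeepTrack.

```python
import numpy as np

def position_in_pixels(position, position_unit, pixel_size):
    """Hypothetical helper (not part of DeepTrack): express a position in pixels."""
    position = np.asarray(position, dtype=float)
    if position_unit == "pixel":
        return position
    if position_unit == "meter":
        # pixel_size is in meters per pixel, so meters / (m/px) gives pixels
        return position / pixel_size
    raise ValueError(f"unknown position_unit: {position_unit!r}")

# A particle at (3.2 µm, 1.6 µm), imaged with 0.1 µm pixels
print(position_in_pixels((3.2e-6, 1.6e-6), "meter", 0.1e-6))
```

With 0.1 µm pixels, this places the particle at approximately (32, 16) pixels, which is the position used in the cell below.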

In [None]:
point_particle = PointParticle(                                         
    intensity=100,
    position=(32, 16),
    position_unit="pixel"
)

## 3. Define the optics 

Next, we need to define the properties of the optical system. This is done using an instance of the class `OpticalDevice`, which takes a set of light scatterers (particles) and convolves them with the point spread function of the optical system, which is determined by its pupil function. In this tutorial, there is only one light scatterer (here, `point_particle`).

The optical device is controlled by the following parameters:

* NA: The numerical aperture

* pixel_size: The pixel-to-meter conversion factor (m/px)

* wavelength: The wavelength of the light source (m)

* mode: Whether the light emitted by the object is "coherent" or "incoherent"

* ROI: The region of interest that is imaged (to avoid wrap-around effects when Fourier-transforming)

* upscale: The upscale factor for the pupil function (increases accuracy and computational cost)
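To illustrate what the optical device does, here is a NumPy-only sketch of incoherent imaging with an ideal circular pupil. The pupil passes spatial frequencies up to NA / wavelength. The incoherent PSF is the squared modulus of the pupil's inverse Fourier transform, and the image is the point particle convolved with that PSF. This is an illustration under stated assumptions, not DeepTrack's implementation. `incoherent_psf` and `image_point_particle` are hypothetical names, and the (x, y) = (column, row) convention is assumed. The FFT-based convolution is circular and wraps light around the image edges, which is the kind of artifact the `ROI` parameter is meant to avoid.

```python
import numpy as np

def incoherent_psf(shape, NA, wavelength, pixel_size):
    """Incoherent PSF of an ideal circular pupil, peaked at index (0, 0)."""
    # Spatial frequencies (cycles per meter) sampled by the FFT grid
    fy = np.fft.fftfreq(shape[0], d=pixel_size)
    fx = np.fft.fftfreq(shape[1], d=pixel_size)
    FX, FY = np.meshgrid(fx, fy)  # both arrays have shape `shape`
    # Ideal pupil: transmits spatial frequencies up to the cutoff NA / wavelength
    pupil = (np.hypot(FX, FY) <= NA / wavelength).astype(float)
    # Coherent PSF = inverse FT of the pupil; incoherent PSF = its squared modulus
    psf = np.abs(np.fft.ifft2(pupil)) ** 2
    return psf / psf.sum()  # normalize so the total intensity is preserved

def image_point_particle(position, intensity, shape, NA, wavelength, pixel_size):
    """Hypothetical sketch: image a point particle at pixel position (x, y)."""
    x, y = position  # assumed convention: x = column, y = row
    scatterer = np.zeros(shape)
    scatterer[int(round(y)), int(round(x))] = intensity
    psf = incoherent_psf(shape, NA, wavelength, pixel_size)
    # Convolution theorem: multiply the spectra, then transform back.
    # This is a circular convolution, so light wraps around the edges.
    return np.real(np.fft.ifft2(np.fft.fft2(scatterer) * np.fft.fft2(psf)))

image = image_point_particle(
    (32, 16), 100, (64, 64), NA=0.7, wavelength=680e-9, pixel_size=0.1e-6
)
```

Working in Fourier space keeps the sketch short. The PSF peak ends up at the particle position, and the total intensity (100) is conserved.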

In [None]:
optics = OpticalDevice(
    NA=0.7,                # numerical aperture
    pixel_size=0.1e-6,     # 0.1 µm per pixel (m/px)
    wavelength=680e-9      # 680 nm (m)
)

## 4. Create and plot the image

To image an object through an optical device, we call the optical device (here, `optics`) with the object we want to image (here, `point_particle`). This creates a new object (here, `imaged_particle`) that can be used to generate the desired image.

The image is then generated by calling `imaged_particle.resolve(input_image)`, where `input_image` is a NumPy array with the desired image shape (this can be seen as the background image).

In [None]:
imaged_particle = optics(point_particle)

input_image = np.zeros((64, 64))
output_image = imaged_particle.resolve(input_image)

plt.imshow(output_image, cmap='gray')
plt.show()

## 5. Randomize the particle position

We can generate particles with random positions by passing a lambda function to the keyword argument `position`. The lambda returns a pair of random numbers representing the particle position. Here, `10 + np.random.rand(2) * 44` draws each coordinate from [10, 54), which keeps the particle about 10 pixels away from the edges of the 64×64 image.

The position used to create an image can be retrieved from the image's `.properties` attribute, which contains a list of all properties used to create the image.
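The retrieval function in the next cell only relies on the image carrying a list of property dictionaries. The sketch below mimics this with a hypothetical NumPy array subclass, `ImageWithProperties`. This is not DeepTrack's actual image class. The sketch also reuses the tutorial's sampling rule to show that each drawn coordinate lands in [10, 54).

```python
import numpy as np

class ImageWithProperties(np.ndarray):
    """Hypothetical stand-in for an image that carries its generation properties."""

    def __new__(cls, array, properties):
        obj = np.asarray(array).view(cls)
        obj.properties = properties
        return obj

    def __array_finalize__(self, obj):
        # Propagate the properties to views and slices of the image
        self.properties = getattr(obj, "properties", [])

# Same sampling rule as the tutorial: each coordinate is drawn from [10, 54)
sample_position = lambda: 10 + np.random.rand(2) * 44

image = ImageWithProperties(
    np.zeros((64, 64)),
    properties=[{"intensity": 100, "position": sample_position(), "position_unit": "pixel"}],
)

def get_position_of_particle(image):
    # Return the first "position" entry found among the image's properties
    for image_property in image.properties:
        if "position" in image_property:
            return image_property["position"]

print(get_position_of_particle(image))
```

Because `__array_finalize__` copies the property list, slices such as `image[:32]` still report the same position.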

In [None]:
# Generate particle with random position

point_particle_with_random_position = PointParticle(                                         
    intensity=100,
    position=lambda: 10 + np.random.rand(2) * 44,
    position_unit="pixel"
)

imaged_particle_with_random_position = optics(point_particle_with_random_position)

input_image = np.zeros((64, 64))
output_image = imaged_particle_with_random_position.resolve(input_image)

plt.imshow(output_image, cmap='gray')


# Retrieve particle position

def get_position_of_particle(image):
    for image_property in image.properties:
        if "position" in image_property:
            return image_property["position"]

position_of_particle = get_position_of_particle(output_image)

plt.scatter(position_of_particle[0], position_of_particle[1])
plt.show()

## 6. Define the neural network model

To track the particle, we use a predefined neural network model, obtained by calling the function `convolutional`. This model is a convolutional neural network with a dense top. It receives an input of shape (64, 64, 1) and outputs two values (the x and y positions of the particle).
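The tutorial describes `convolutional` only as a convolutional network with a dense top that maps a (64, 64, 1) input to two outputs. The NumPy-only forward pass below shows how those shapes flow through such a network, using random, untrained weights. The layer sizes (two 3×3 convolution blocks with 8 and 16 filters, each followed by 2×2 max pooling) are assumptions for illustration. They are not the layers that `convolutional` actually builds.

```python
import numpy as np

rng = np.random.default_rng(0)

def conv2d_valid(x, kernels):
    """'Valid' 2D convolution (cross-correlation, as in deep-learning libraries).
    x: (H, W, C_in), kernels: (k, k, C_in, C_out) -> (H-k+1, W-k+1, C_out)."""
    k = kernels.shape[0]
    # All k x k windows: shape (H-k+1, W-k+1, C_in, k, k)
    windows = np.lib.stride_tricks.sliding_window_view(x, (k, k), axis=(0, 1))
    return np.einsum("hwcij,ijco->hwo", windows, kernels)

def max_pool2(x):
    """2x2 max pooling, dropping an odd trailing row/column."""
    h, w, c = x.shape
    x = x[: h - h % 2, : w - w % 2]
    return x.reshape(h // 2, 2, w // 2, 2, c).max(axis=(1, 3))

def relu(x):
    return np.maximum(x, 0)

# Hypothetical architecture: two conv blocks followed by a dense top with 2 outputs
k1 = rng.normal(scale=0.1, size=(3, 3, 1, 8))
k2 = rng.normal(scale=0.1, size=(3, 3, 8, 16))
dense_w = rng.normal(scale=0.01, size=(14 * 14 * 16, 2))
dense_b = np.zeros(2)

def forward(image):
    x = max_pool2(relu(conv2d_valid(image, k1)))  # (64, 64, 1) -> (62, 62, 8) -> (31, 31, 8)
    x = max_pool2(relu(conv2d_valid(x, k2)))      # -> (29, 29, 16) -> (14, 14, 16)
    return x.reshape(-1) @ dense_w + dense_b      # flatten 3136 values -> 2 outputs

output = forward(rng.random((64, 64, 1)))
print(output.shape)
```

The convolution and pooling layers extract local image features. The dense top then combines them into a single (x, y) estimate.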

In [None]:
model = convolutional(
    input_shape=(64, 64, 1), 
    number_of_outputs=2
)

## 7. Define image generator

Generators are objects that feed models with images and their corresponding labels during training. They are created by calling `.generate()` on an instance of the class `Generator`. This method takes the following inputs:
* feature: A feature that resolves images used to train a model
* label_function: A function that takes an image as input and returns the label for that image
* shape: The shape of the output image
* batch_size: The number of images per batch
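Conceptually, a generator of this kind repeatedly resolves a batch of images, applies the label function to each image, and yields the resulting pair. The sketch below is a hypothetical NumPy stand-in for `Generator().generate(...)`, not its implementation:

* `batch_generator`, `resolve_gaussian_spot`, and `scaled_centroid` are illustrative names.
* The Gaussian spot replaces the imaged point particle.
* The label is an intensity-weighted centroid rather than a stored property.
* The trailing channel axis, added to match the (64, 64, 1) model input, is also an assumption.

```python
import numpy as np

rng = np.random.default_rng(0)

def batch_generator(resolve_image, label_function, shape, batch_size):
    """Hypothetical sketch: endlessly yield (images, labels) batches."""
    while True:
        images, labels = [], []
        for _ in range(batch_size):
            image = resolve_image(np.zeros(shape))  # fresh image on a blank background
            images.append(image)
            labels.append(label_function(image))
        # Add a trailing channel axis so the batch matches an (H, W, 1) model input
        yield np.stack(images)[..., np.newaxis], np.stack(labels)

def resolve_gaussian_spot(background, sigma=2.0):
    """Hypothetical stand-in for a feature: a Gaussian spot at a random position."""
    rows, cols = np.indices(background.shape)
    y0, x0 = 10 + rng.random(2) * 44
    return background + np.exp(-((rows - y0) ** 2 + (cols - x0) ** 2) / (2 * sigma**2))

def scaled_centroid(image):
    """Intensity-weighted (x, y) centroid, divided by 64 to lie between 0 and 1."""
    rows, cols = np.indices(image.shape)
    total = image.sum()
    return np.array([(cols * image).sum(), (rows * image).sum()]) / total / 64

images, labels = next(batch_generator(resolve_gaussian_spot, scaled_centroid, (64, 64), 4))
print(images.shape, labels.shape)
```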

In [None]:
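# Function that retrieves the position of a particle
# and divides it by 64 (the image size) to get values between 0 and 1
def get_scaled_position_of_particle(image):
    position_of_particle = get_position_of_particle(image)
    return np.array(position_of_particle) / 64

# Train on the particle with a random position, so the position varies
# between training images, and use the scaled position as the label
generator = Generator().generate(
    imaged_particle_with_random_position,
    get_scaled_position_of_particle,
    shape=(64, 64),
    batch_size=4
)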

## 8. Train the model

The model is trained by calling `.fit()` with the generator we defined in the previous step.
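With the settings used here, the amount of training data follows from the batch size and the arguments to `.fit()`, since each epoch draws `steps_per_epoch` batches from the generator. Because the particle position is randomized, these images are not one fixed dataset repeated across epochs. This assumes the generator resolves new images for every batch.

```python
batch_size = 4          # images per batch, as passed to the generator
steps_per_epoch = 64    # batches drawn from the generator per epoch
epochs = 1000

images_per_epoch = batch_size * steps_per_epoch  # 256
total_images = images_per_epoch * epochs         # 256000
print(images_per_epoch, total_images)
```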

In [None]:
model.fit(
    generator,
    epochs=1000,
    steps_per_epoch=64
)

## 9. Visualize the model performance

We can now use the trained model to measure the particle position in images previously unseen by the model.
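Besides visual inspection, the predictions can be compared with the true positions numerically. Below is a minimal sketch. It assumes that both arrays hold positions scaled by 64, with shape (batch_size, 2), as produced by the label function above. `mean_position_error` is a hypothetical helper that reports the mean Euclidean distance in pixels.

```python
import numpy as np

def mean_position_error(predicted, real, image_size=64):
    """Mean Euclidean distance (in pixels) between predicted and real positions."""
    # Undo the scaling applied by the label function before measuring distances
    difference = (np.asarray(predicted) - np.asarray(real)) * image_size
    return np.linalg.norm(difference, axis=1).mean()

# Two particles: one predicted exactly, one off by (3, 4) pixels
predicted = np.array([[0.5, 0.5], [0.25, 0.25]])
real = np.array([[0.5, 0.5], [0.25 + 3 / 64, 0.25 + 4 / 64]])
print(mean_position_error(predicted, real))  # 2.5
```

Applied to the model's predictions and the labels returned by `next(generator)`, this gives a single number that can be tracked as training progresses.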

In [None]:
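images, real_positions = next(generator)

measured_positions = model.predict(images)

for i in range(images.shape[0]):

    image = np.squeeze(images[i])
    plt.imshow(image, cmap='gray')

    # Undo the scaling by 64 applied to the labels
    measured_position_x = measured_positions[i, 0] * 64
    measured_position_y = measured_positions[i, 1] * 64
    plt.scatter(measured_position_x, measured_position_y, label="measured")

    # Ground-truth position, for comparison
    plt.scatter(real_positions[i][0] * 64, real_positions[i][1] * 64, marker="x", label="real")
    plt.legend()

    plt.show()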