# Segmentation of copper nanoparticles in TEM images using CellPose

In this notebook, we analyze the growth of copper (Cu) nanoparticles in a transmission electron microscopy (TEM) timeseries.

We will use

- `NumPy` to denoise the timeseries
- `CellPose` to segment the particles
- `scikit-image` to measure particle size
- `Napari` to visualize the results

```{admonition} Acknowledgements
This notebook was adapted from a [project](https://gitlab.com/epfl-center-for-imaging/nanoparticles-segmentation) that was part of a collaboration between the EPFL Center for Imaging and the Laboratory for in situ Nanomaterials Characterization with Electrons.
```

In [None]:
import numpy as np
from magicgui import magicgui
import napari
from napari.layers import Image, Labels
from napari.types import LayerDataTuple
from skimage.measure import regionprops_table
from skimage.io import imread
from cellpose import models
from cellpose.utils import stitch3D
from magicgui.tqdm import trange

## Load the timeseries image

The image axes should be in `TYX` order (time, then rows, then columns), with a single grayscale channel per frame.
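If a file stores time last (for example as `YXT`), the axes can be reordered before analysis. The following is a minimal sketch, assuming a 3D grayscale array; `to_tyx` and its `axes` string argument are hypothetical helpers written for illustration, not part of scikit-image or Napari.

```python
import numpy as np


def to_tyx(array: np.ndarray, axes: str) -> np.ndarray:
    """Return `array` with its axes reordered to (T, Y, X).

    `axes` names the current order of the three axes, e.g. "YXT".
    Hypothetical convenience helper, not a library function.
    """
    axes = axes.upper()
    if array.ndim != 3 or sorted(axes) != ["T", "X", "Y"]:
        raise ValueError("expected a 3D grayscale array and an axes string using T, Y and X")
    # Current positions of the T, Y and X axes, in that order
    source = [axes.index(name) for name in "TYX"]
    return np.moveaxis(array, source, [0, 1, 2])


# Example: 64 x 48 pixel frames stored with time as the last axis
stack_yxt = np.zeros((64, 48, 10))
stack_tyx = to_tyx(stack_yxt, "YXT")
print(stack_tyx.shape)  # (10, 64, 48)
```

`np.moveaxis` returns a view, so no pixel data is copied.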

In [None]:
image = imread('./data/new_dataset_downscaled_fiji-2-cropped.tif')

image.shape

### Load the image into Napari

Napari will help us process the image interactively and visualize the results of our analysis.

In [None]:
viewer = napari.Viewer()

viewer.add_image(image)

## Temporal averaging for denoising

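To denoise the image, we replace each frame by the average of the frames in a sliding window centered on it (a moving average along the time axis). The window length, in frames, is chosen interactively. We implement this functionality as a Napari dock widget using the `magicgui` library.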
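Averaging `n` frames whose noise is independent reduces the noise standard deviation by about √n, which is why the temporal average denoises; the price is that changes faster than the window, such as rapid particle growth, get blurred. The sketch below checks this on synthetic data. `temporal_moving_average` is a hypothetical stand-alone helper that reproduces the moving average without Napari; unlike a plain zero-padded `np.convolve`, it divides by the number of frames actually present in each window, so the first and last frames are not darkened.

```python
import numpy as np


def temporal_moving_average(stack: np.ndarray, n_frames: int) -> np.ndarray:
    """Centered moving average of a (T, Y, X) stack along the time axis.

    Near the start and end of the stack, each frame is averaged over the
    frames that actually exist instead of being padded with zeros.
    """
    stack = np.asarray(stack, dtype=float)
    t = stack.shape[0]
    n_frames = max(1, min(n_frames, t))
    kernel = np.ones(n_frames)
    # Number of real frames falling inside each window position
    counts = np.convolve(np.ones(t), kernel, mode="same")
    # Windowed sum of every pixel's time signal
    sums = np.apply_along_axis(
        lambda signal: np.convolve(signal, kernel, mode="same"), 0, stack
    )
    return sums / counts[:, None, None]


rng = np.random.default_rng(0)
clean = np.full((50, 32, 32), 100.0)                # static synthetic scene
noisy = clean + rng.normal(0.0, 10.0, clean.shape)  # independent noise per frame
denoised = temporal_moving_average(noisy, 9)

print(np.std(noisy - clean))              # close to 10
print(np.std(denoised[4:-4] - clean[4:-4]))  # about 10 / 3 for frames with a full window
```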

In [None]:
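@magicgui(
    call_button='Run temporal average',
    layout='vertical',
    image_layer={'label': 'Image'},
    Frames_number={"widget_type": "SpinBox", "min": 1},
)
def temporal_averaging(image_layer: Image, Frames_number: int=1) -> LayerDataTuple:
    if image_layer is None:
        return

    image = np.asarray(image_layer.data, dtype=float)  # (T, Y, X)
    t, y, x = image.shape

    # The averaging window cannot be longer than the timeseries
    n = min(Frames_number, t)
    kernel = np.ones(n)

    # Moving sum along the time axis, computed independently for every pixel
    pixels = image.reshape((t, y * x))
    window_sums = np.apply_along_axis(
        lambda signal: np.convolve(signal, kernel, mode='same'),
        axis=0,
        arr=pixels,
    )
    # Divide by the number of frames actually inside each window so that
    # the first and last frames are not darkened by zero padding
    counts = np.convolve(np.ones(t), kernel, mode='same')
    time_filtered = (window_sums / counts[:, None]).reshape((t, y, x))

    return (time_filtered, {'name': 'denoised'}, 'image')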


viewer.window.add_dock_widget(temporal_averaging, name='Denoising (temporal average)')

## Segmentation of copper particles using `CellPose`

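CellPose provides generalist segmentation models trained on cells and nuclei, and these models also work on many other roughly circular objects in scientific images. In the cell below, we implement interactive segmentation in Napari based on CellPose: every frame is segmented independently with the pretrained `cyto` model (with `diameter=None`, CellPose estimates the object size itself), and `stitch3D` then links masks in consecutive frames that overlap with an intersection over union (IoU) of at least 0.5, so that a particle keeps the same label over time.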
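Because each frame is segmented on its own, the same particle can receive different label numbers in different frames, and `stitch3D` relinks them by mask overlap. The sketch below illustrates the idea with a simplified greedy IoU matcher in plain NumPy. `link_labels_by_iou` is a hypothetical helper written for illustration; it is not CellPose's implementation and may differ from it in edge cases.

```python
import numpy as np


def link_labels_by_iou(frames: np.ndarray, iou_threshold: float = 0.5) -> np.ndarray:
    """Relabel a (T, Y, X) stack of per-frame masks so that objects overlapping
    in consecutive frames share one label. Simplified illustration only."""
    frames = np.asarray(frames)
    linked = np.zeros(frames.shape, dtype=np.int64)
    linked[0] = frames[0]
    next_label = int(frames[0].max()) + 1
    for t in range(1, len(frames)):
        previous, current = linked[t - 1], frames[t]
        for label in np.unique(current):
            if label == 0:
                continue  # 0 is background
            mask = current == label
            # Candidate matches: labels of the previous frame under this mask
            best_iou, best_label = 0.0, 0
            for candidate in np.unique(previous[mask]):
                if candidate == 0:
                    continue
                other = previous == candidate
                iou = np.logical_and(mask, other).sum() / np.logical_or(mask, other).sum()
                if iou > best_iou:
                    best_iou, best_label = iou, candidate
            # Strict ">" with a threshold >= 0.5 prevents two disjoint masks
            # from both matching the same previous object
            if best_iou > iou_threshold:
                linked[t][mask] = best_label
            else:
                linked[t][mask] = next_label
                next_label += 1
    return linked


# A particle that grows between two frames, plus a new particle in frame 1
frames = np.zeros((2, 8, 8), dtype=int)
frames[0, 2:5, 2:5] = 1  # particle A in frame 0 (label 1)
frames[1, 2:6, 2:6] = 7  # particle A in frame 1, labeled 7 by the segmenter
frames[1, 0:2, 6:8] = 3  # a particle appearing in frame 1
linked = link_labels_by_iou(frames)
print(np.unique(linked[1]))  # [0 1 2]
```

Here particle A has IoU 9/16 ≈ 0.56 with its previous mask, so it keeps label 1, while the new particle gets the fresh label 2.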

In [None]:
@magicgui(call_button='Run detection', image_layer={'label': 'Image'})
def cellpose_dock_widget(image_layer: Image, use_gpu: bool=True) -> LayerDataTuple:
    if image_layer is None:
        return

    image = image_layer.data  # (T, Y, X)

    # Pretrained generalist 'cyto' model; use_gpu=True requests the GPU
    model = models.Cellpose(gpu=use_gpu, model_type='cyto')

    # Segment each frame independently
    cellpose_labels = np.zeros(image.shape, dtype=np.int32)
    for frame_idx in trange(len(image)):
        frame = image[frame_idx]
        # diameter=None lets CellPose estimate the particle size in each frame
        cellpose_mask, *_ = model.eval(frame, diameter=None, flow_threshold=None)
        cellpose_labels[frame_idx] = cellpose_mask

    # Link masks across consecutive frames when their IoU is at least 0.5
    cellpose_labels = stitch3D(cellpose_labels, stitch_threshold=0.5)

    return (cellpose_labels, {'name': 'cellpose', 'opacity': 0.4}, 'labels')

viewer.window.add_dock_widget(cellpose_dock_widget, name='Cellpose (all frames)')

## Computing and plotting particle size

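We use scikit-image's `regionprops_table` function to measure the area (in pixels) and centroid of every labeled particle in each frame. The areas are stored as features of a Napari `Points` layer placed at the particle centroids and displayed as text overlaid on the segmentation.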
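Because the labels are stitched across frames, a label's area in successive frames traces that particle's growth. The following is a minimal NumPy sketch of that bookkeeping, assuming a (T, Y, X) stack of non-negative integer labels with 0 as background. `area_per_label_over_time` is a hypothetical helper that computes only pixel areas (the `area` column of `regionprops_table`). The equivalent circular diameter, √(4A/π), is the quantity scikit-image exposes as `equivalent_diameter_area` (named `equivalent_diameter` in older releases).

```python
import numpy as np


def area_per_label_over_time(labels: np.ndarray) -> dict[int, np.ndarray]:
    """Map each label in a (T, Y, X) stack to its pixel area in every frame.

    Frames where a label is absent get an area of 0.
    """
    labels = np.asarray(labels)
    ids = np.unique(labels)
    ids = ids[ids != 0]  # drop background
    max_label = int(labels.max())
    # Pixel count of every label value, one row per frame
    counts = np.stack(
        [np.bincount(frame.ravel(), minlength=max_label + 1) for frame in labels]
    )
    return {int(i): counts[:, i] for i in ids}


# One synthetic particle growing over three frames
labels = np.zeros((3, 10, 10), dtype=int)
labels[0, 4:6, 4:6] = 1  # 4 px
labels[1, 3:7, 3:7] = 1  # 16 px
labels[2, 2:8, 2:8] = 1  # 36 px
growth = area_per_label_over_time(labels)
print(growth[1])  # [ 4 16 36]

# Equivalent circular diameter in pixels, frame by frame
diameters = np.sqrt(4 * growth[1] / np.pi)
```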

In [None]:
@magicgui(call_button='Measure particle size', layout='vertical', labels_layer={'label': 'Detections'})
def region_properties_widget(labels_layer: Labels) -> LayerDataTuple:
    if labels_layer is None:
        return

    labels = labels_layer.data  # (T, Y, X) integer labels

    points = []
    areas = []
    times = []
    labels_idx = []
    for k, label_frame in enumerate(labels):
        props = regionprops_table(label_frame, properties=['label', 'centroid', 'area'])
        # centroid-0 is the row (Y) and centroid-1 the column (X) coordinate
        for row, col, area, lab in zip(props['centroid-0'], props['centroid-1'], props['area'], props['label']):
            points.append([k, row, col])  # Napari point coordinates in (T, Y, X) order
            areas.append(area)
            times.append(k)
            labels_idx.append(lab)

    points_props = {
        'name': 'area (px)',
        'face_color': 'blue',
        'edge_width': 0,
        'size': 0,  # invisible markers: only the text annotations are shown
        'features': {'area': areas, 'time': times, 'label': labels_idx},
        'text': {
            'string': '{area:.0f}',  # particle area in pixels
            'size': 9,
            'color': 'white'
        }
    }

    # reshape keeps a valid (0, 3) array when no particle is detected
    return (np.asarray(points, dtype=float).reshape(-1, 3), points_props, 'points')

viewer.window.add_dock_widget(region_properties_widget, name='Measure')