<a href="https://colab.research.google.com/github/AdaptiveMotorControlLab/CellSeg3d/blob/main/notebooks/Colab_inference_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **CellSeg3D : inference demo notebook**

---
This notebook is part of the [CellSeg3D project](https://github.com/AdaptiveMotorControlLab/CellSeg3d) in the [Mathis Lab of Adaptive Intelligence](https://www.mackenziemathislab.org/).

- 💜 The foundation of this notebook owes much to the **[ZeroCostDL4Mic](https://github.com/HenriquesLab/ZeroCostDL4Mic)** project and to the **[DeepLabCut](https://github.com/DeepLabCut/DeepLabCut)** team.

# **1. Installing dependencies**
---

## **1.1 Installing CellSeg3D**
---

In [1]:
#@markdown ##Install CellSeg3D and grab demo data
!git clone https://github.com/AdaptiveMotorControlLab/CellSeg3d.git --branch main --single-branch ./CellSeg3D
!pip install napari-cellseg3d
!pip install pyClesperanto
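Before moving on, it can help to confirm which versions pip actually installed. Below is a minimal sketch that uses the standard-library `importlib.metadata`. The helper `installed_version` is hypothetical, and the distribution names it queries are assumptions based on the pip commands above.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# Distribution names assumed from the pip commands above
for name in ("napari-cellseg3d", "pyclesperanto"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```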




## **1.2 Load key dependencies**
---

In [2]:
# @title Load libraries
import napari_cellseg3d
from tifffile import imread
from pathlib import Path
from napari_cellseg3d.dev_scripts import remote_inference as cs3d
from napari_cellseg3d.utils import LOGGER as logger
from napari_cellseg3d.config import MODEL_LIST, ModelInfo

import logging

logger.setLevel(logging.INFO)



# **2. Inference**
---


## **2.1 Check for GPU access**
---

By default, this session is configured to use Python 3 and GPU acceleration. To verify or adjust these settings:

1. Navigate to **Runtime** and select **Change runtime type**.
2. For **Runtime type**, make sure **Python 3** is selected (the language this notebook is written in).
3. Under **Hardware accelerator**, choose a GPU (Graphics Processing Unit).


In [3]:
#@markdown This cell checks whether GPU access is available.

import torch
if not torch.cuda.is_available():
    print('You do not have GPU access.')
    print('Did you change your runtime?')
    print('If the runtime setting is correct, then Google did not allocate a GPU for your session.')
    print('Expect slow performance. To access a GPU, try reconnecting later.')
else:
    print('You have GPU access')
    !nvidia-smi


You have GPU access
Sun Dec 15 21:09:57 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   46C    P8               9W /  70W |      3MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
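Downstream code usually needs a device string rather than a printed message. A minimal sketch that falls back to the CPU when PyTorch or CUDA is unavailable; `pick_device` is a hypothetical helper, not part of CellSeg3D.

```python
def pick_device():
    """Return "cuda" if PyTorch can see a GPU, otherwise "cpu"."""
    try:
        import torch
    except ImportError:
        # PyTorch is not installed: only the CPU path is possible
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


device = pick_device()
print(f"Using device: {device}")
```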
                                                

## **2.2 Run inference**
---

In [None]:
# Create a dropdown menu to choose a model from MODEL_LIST

import ipywidgets as widgets
from IPython.display import display

model_list = list(MODEL_LIST.keys())

model_dropdown = widgets.Dropdown(
    options=model_list,
    description='Model:',
    disabled=False,
)

display(model_dropdown)
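If the dropdown does not render (for example, when the notebook runs outside Colab or Jupyter), the model name can be set directly instead. A minimal sketch, assuming a dictionary-like registry such as `MODEL_LIST`. `choose_model` and `example_registry` are hypothetical, so read the real keys from `MODEL_LIST` itself.

```python
def choose_model(name, registry):
    """Return `name` if it is a key of `registry`, otherwise raise a helpful error."""
    if name not in registry:
        available = ", ".join(sorted(registry))
        raise KeyError(f"Unknown model {name!r}; available: {available}")
    return name


# Hypothetical stand-in for MODEL_LIST (maps model names to model classes)
example_registry = {"SwinUNetR": object, "SegResNet": object}

model_name = choose_model("SwinUNetR", example_registry)
print(model_name)
```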

In [4]:
demo_image_path = "/content/CellSeg3D/examples/c5image.tif"
demo_image = imread(demo_image_path)
inference_config = cs3d.CONFIG
inference_config.model_info = ModelInfo(
    name=model_dropdown.value,
    model_input_size=[64, 64, 64],
    num_classes=2,
)
post_process_config = cs3d.PostProcessConfig(threshold=MODEL_LIST[model_dropdown.value].default_threshold)
# Select the clesperanto (cle) device: on Colab, use the experimental CuPy/CUDA backend
import pyclesperanto_prototype as cle
cle.select_device("cupy")

'cupy backend (experimental)'

In [5]:
result = cs3d.inference_on_images(
    demo_image,
    config=inference_config,
)

--------------------
Parameters summary :
Model is : SwinUNetR
Window inference is enabled
Window size is 64
Window overlap is 0.25
Dataset loaded on cuda device
--------------------
MODEL DIMS : 64
Model name : SwinUNetR
Instantiating model...


monai.networks.nets.swin_unetr SwinUNETR.__init__:img_size: Argument `img_size` has been deprecated since version 1.3. It will be removed in version 1.5. The img_size argument is not required anymore and checks on the input size are run during forward().
INFO:napari_cellseg3d.utils:********************
INFO:napari_cellseg3d.utils:Weight file SwinUNetR_latest.pth already exists, skipping download


Loading weights...


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.


Weights status : <All keys matched successfully>
Done
--------------------
Parameters summary :
Model is : SwinUNetR
Window inference is enabled
Window size is 64
Window overlap is 0.25
Dataset loaded on cuda device
--------------------
Loading layer
2024-12-15 21:10:06,566 - INFO - Apply pending transforms - lazy: False, pending: 0, upcoming 'QuantileNormalization', transform is not lazy
2024-12-15 21:10:06,592 - INFO - Apply pending transforms - lazy: False, pending: 0, upcoming 'ToTensor', transform is not lazy
2024-12-15 21:10:06,595 - INFO - Apply pending transforms - lazy: False, pending: 0, upcoming 'EnsureType', transform is not lazy
Done
----------
Inference started on layer...
Post-processing...
Layer prediction saved as : volume_SwinUNetR_pred_1_2024_12_15_21_10_09
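The log reports sliding-window inference with a window size of 64 and an overlap of 0.25. The sketch below roughly illustrates what this means for the demo volume, whose image size the statistics report as (124, 86, 94). It is a simplified sketch of MONAI-style tiling: windows advance by `int(64 * (1 - 0.25)) = 48` voxels, and the last window is shifted back to end at the volume edge. `window_starts` is hypothetical, and the exact positions CellSeg3D/MONAI use may differ.

```python
import math


def window_starts(length, window, overlap):
    """Start indices of windows of size `window` covering an axis of size `length`.

    Consecutive windows advance by int(window * (1 - overlap)); the last window
    is shifted back so that it ends exactly at the edge of the volume.
    """
    if window >= length:
        return [0]
    step = max(int(window * (1 - overlap)), 1)
    starts = []
    start = 0
    while start + window < length:
        starts.append(start)
        start += step
    starts.append(length - window)  # last window flush with the edge
    return starts


shape = (124, 86, 94)  # demo volume shape, from the statistics table
per_axis = [window_starts(n, 64, 0.25) for n in shape]
print(per_axis)  # [[0, 48, 60], [0, 22], [0, 30]]
print(math.prod(len(s) for s in per_axis))  # 12 windows in total
```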


In [6]:
# @title Post-process the result
# @markdown This cell post-processes the inference result: thresholding, instance segmentation, and statistics.
instance_segmentation, stats = cs3d.post_processing(
    result[0].semantic_segmentation,
    config=post_process_config,
)

1it [00:00, 11.29it/s]
clesperanto's cupy / CUDA backend is experimental. Please use it with care. The following functions are known to cause issues in the CUDA backend:
affine_transform, apply_vector_field, create(uint64), create(int32), create(int64), resample, scale, spots_to_pointlist
divide by zero encountered in scalar divide
invalid value encountered in scalar multiply
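Conceptually, post-processing turns the semantic prediction into separate labeled objects. The sketch below is a simplified stand-in that assumes a probability-like volume. It binarizes the volume with a threshold (the notebook takes this value from `default_threshold`) and then labels 6-connected foreground voxels with a breadth-first search. This is plain connected-component labeling. It is not necessarily the instance method configured in `PostProcessConfig`, which the clesperanto messages above suggest relies on pyclesperanto. `label_components` is a hypothetical helper.

```python
from collections import deque

import numpy as np


def label_components(prob, threshold):
    """Threshold a 3D array and label 6-connected foreground voxels 1..N."""
    mask = prob > threshold
    labels = np.zeros(mask.shape, dtype=np.int32)
    offsets = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]
    current = 0
    for seed in zip(*np.nonzero(mask)):
        if labels[seed]:
            continue  # voxel already belongs to a labeled object
        current += 1
        labels[seed] = current
        queue = deque([seed])
        while queue:
            z, y, x = queue.popleft()
            for dz, dy, dx in offsets:
                n = (z + dz, y + dy, x + dx)
                # Bounds check first, so negative indices never wrap around
                if all(0 <= c < s for c, s in zip(n, mask.shape)) and mask[n] and not labels[n]:
                    labels[n] = current
                    queue.append(n)
    return labels, current


# Toy "probability" volume with two separate blobs
prob = np.zeros((4, 4, 4))
prob[0:2, 0:2, 0:2] = 0.9
prob[3, 3, 3] = 0.8
labels, n_objects = label_components(prob, threshold=0.5)
print(n_objects)  # 2
```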


In [7]:
# @title Display the result
# @markdown This cell displays the result of the inference and post-processing. Use the slider to navigate through the z-stack.
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display
import matplotlib
import colorsys
import numpy as np

def random_label_cmap(n=2**16, h=(0, 1), l=(.4, 1), s=(.2, .8)):
    """Function taken from the StarDist repository: https://github.com/stardist/stardist/blob/c6c261081c6f9717fa9f5c47720ad2d5a9153224/stardist/plot/plot.py#L8"""
    h, l, s = np.random.uniform(*h, n), np.random.uniform(*l, n), np.random.uniform(*s, n)
    cols = np.stack([colorsys.hls_to_rgb(_h, _l, _s) for _h, _l, _s in zip(h, l, s)], axis=0)
    # Label 0 (background) is drawn in black
    cols[0] = 0
    return matplotlib.colors.ListedColormap(cols)

label_cmap = random_label_cmap(n=int(instance_segmentation.max()) + 1)

def update_plot(z):
    plt.figure(figsize=(15, 15))
    plt.subplot(1, 3, 1)
    plt.imshow(demo_image[z], cmap='gray')
    plt.subplot(1, 3, 2)
    plt.imshow(result[0].semantic_segmentation[z], cmap='turbo')
    plt.subplot(1, 3, 3)
    # Fix the color range to the full label range (0..N-1) so each label keeps the same
    # color on every slice; without vmin/vmax, matplotlib rescales the colors per slice.
    # "nearest" interpolation avoids blending neighboring label values.
    plt.imshow(instance_segmentation[z], cmap=label_cmap, vmin=0, vmax=label_cmap.N - 1, interpolation='nearest')
    plt.show()

# Create a slider
z_slider = widgets.IntSlider(min=0, max=demo_image.shape[0] - 1, step=1, value=demo_image.shape[0] // 2)

# Display the slider and update the plot when the slider is changed
widgets.interact(update_plot, z=z_slider)
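`random_label_cmap` draws new colors every time the cell runs, so label colors change from one run to the next. A minimal sketch of a reproducible variant, assuming NumPy's `default_rng`. `seeded_label_colors` is a hypothetical helper that returns the RGB array, which can then be wrapped in `matplotlib.colors.ListedColormap` just like in the cell above.

```python
import colorsys

import numpy as np


def seeded_label_colors(n, seed=0, h=(0, 1), l=(.4, 1), s=(.2, .8)):
    """Return an (n, 3) RGB array of random label colors; row 0 (background) is black."""
    rng = np.random.default_rng(seed)  # same seed -> same colors on every run
    hs, ls, ss = rng.uniform(*h, n), rng.uniform(*l, n), rng.uniform(*s, n)
    cols = np.array([colorsys.hls_to_rgb(a, b, c) for a, b, c in zip(hs, ls, ss)])
    cols[0] = 0
    return cols


colors = seeded_label_colors(5, seed=42)
print(colors.shape)  # (5, 3)
```

Usage in the display cell would then be `label_cmap = matplotlib.colors.ListedColormap(seeded_label_colors(int(instance_segmentation.max()) + 1))`.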


In [8]:
# @title Display the statistics
# @markdown This cell displays the statistics of the post-processed result.
import pandas as pd
data = pd.DataFrame(stats.get_dict())
display(data)

| | Volume | Centroid x | Centroid y | Centroid z | Sphericity (axes) | Image size | Total image volume | Total object volume (pixels) | Filling ratio | Number objects |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 190.0 | 5.405263 | 69.157895 | 36.210526 | 0.778113 | (124, 86, 94) | 1002416 | 33504.0 | 0.033423 | 322 |
| 1 | 18.0 | 5.833333 | 85.000000 | 83.944444 | 0.000007 | | | | | |
| 2 | 67.0 | 7.283582 | 65.492537 | 92.059701 | 0.867751 | | | | | |
| 3 | 108.0 | 10.324074 | 84.342593 | 68.861111 | 0.672490 | | | | | |
| 4 | 35.0 | 9.428571 | 84.314286 | 92.600000 | 0.649649 | | | | | |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 317 | 11.0 | 122.363636 | 14.727273 | 25.000000 | 0.951651 | | | | | |
| 318 | 24.0 | 122.166667 | 26.083333 | 38.083333 | 0.990075 | | | | | |
| 319 | 16.0 | 122.125000 | 34.125000 | 36.500000 | 0.944672 | | | | | |
| 320 | 13.0 | 122.076923 | 43.538462 | 53.615385 | 0.939852 | | | | | |
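The global columns can be derived from simple counts. `Total image volume` is the product of the image dimensions (124 × 86 × 94 = 1,002,416). `Filling ratio` matches `Total object volume (pixels)` divided by that value (33,504 / 1,002,416 ≈ 0.033423). Below is a minimal sketch of how per-object volume and centroid, along with these totals, can be computed from a label array. `label_stats` is hypothetical and is not CellSeg3D's implementation. It also omits sphericity.

```python
import numpy as np


def label_stats(labels):
    """Per-object volume/centroid and global totals for a label array (0 = background)."""
    ids, counts = np.unique(labels[labels > 0], return_counts=True)
    coords = np.nonzero(labels)
    values = labels[coords]
    objects = []
    for obj_id, count in zip(ids, counts):
        member = values == obj_id
        centroid = tuple(float(c[member].mean()) for c in coords)  # mean index along each axis
        objects.append({"id": int(obj_id), "volume": int(count), "centroid": centroid})
    total_image = int(np.prod(labels.shape))
    total_objects = int(counts.sum())
    return objects, {
        "Total image volume": total_image,
        "Total object volume (pixels)": total_objects,
        "Filling ratio": total_objects / total_image,
        "Number objects": len(ids),
    }


toy = np.zeros((4, 4, 4), dtype=int)
toy[0:2, 0:2, 0:2] = 1  # 8-voxel cube
toy[3, 3, 3] = 2        # single voxel
objects, totals = label_stats(toy)
print(objects[0])  # {'id': 1, 'volume': 8, 'centroid': (0.5, 0.5, 0.5)}
print(totals["Filling ratio"])  # 0.140625
```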


In [9]:
# @title Plot a 3D view with statistics
# @markdown This cell plots a 3D view of the cells, with the volume as the size of the points and the sphericity as the color.
import plotly.graph_objects as go
import numpy as np

def plotly_cells_stats(data):

    x = data["Centroid x"]
    y = data["Centroid y"]
    z = data["Centroid z"]

    fig = go.Figure(
        data=go.Scatter3d(
            x=np.floor(x),
            y=np.floor(y),
            z=np.floor(z),
            mode="markers",
            marker=dict(
                sizemode="diameter",
                sizeref=30,
                sizemin=20,
                size=data["Volume"],
                color=data["Sphericity (axes)"],
                colorscale="Turbo_r",
                colorbar_title="Sphericity",
                line_color="rgb(140, 140, 170)",
            ),
        )
    )

    fig.update_layout(
        height=600,
        width=600,
        title=f'Total number of cells : {int(data["Number objects"][0])}',
    )

    fig.show(renderer="colab")

plotly_cells_stats(data)