<a href="https://colab.research.google.com/github/AdaptiveMotorControlLab/CellSeg3d/blob/main/notebooks/colab_inference_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **CellSeg3D : inference demo notebook**

---
This notebook is part of the [CellSeg3D project](https://github.com/AdaptiveMotorControlLab/CellSeg3d) in the [Mathis Lab of Adaptive Intelligence](https://www.mackenziemathislab.org/).

- 💜 The foundation of this notebook owes much to the **[ZeroCostDL4Mic](https://github.com/HenriquesLab/ZeroCostDL4Mic)** project and to the **[DeepLabCut](https://github.com/DeepLabCut/DeepLabCut)** team, who brought Colab to open-source scientific software.

# **1. Installing dependencies**
---

## **1.1 Installing CellSeg3D**
---

In [None]:
#@markdown ##Install CellSeg3D and dependencies
!git clone https://github.com/AdaptiveMotorControlLab/CellSeg3d.git --branch main --single-branch ./CellSeg3D
!pip install -e CellSeg3D

## **1.2 Restart your runtime**
---
<font size = 4>


**<font size = 4> Running the cell below forces the runtime to restart so that the newly installed packages can be imported. Please ignore the resulting error message: the automatic restart is expected.**

<img width="40%" alt ="" src="https://github.com/HenriquesLab/ZeroCostDL4Mic/raw/master/Wiki_files/session_crash.png"><figcaption>  </figcaption>

In [None]:
# @title Force session restart
exit(0)

## **1.3 Load key dependencies**
---

In [None]:
# @title Load libraries
from pathlib import Path
from tifffile import imread
from napari_cellseg3d.dev_scripts import remote_inference as cs3d
from napari_cellseg3d.utils import LOGGER as logger
import logging

logger.setLevel(logging.INFO)
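If the import cell above fails with a `ModuleNotFoundError`, a quick check is whether pip actually registered the package. The sketch below uses only the standard library's `importlib.metadata`; the distribution name `napari-cellseg3d` and the helper `installed_version` are assumptions for illustration, so adjust the name if `pip list` shows something different.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if it is missing.

    Hypothetical helper, not part of CellSeg3D.
    """
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# Assumption: "napari-cellseg3d" is the distribution name installed by
# `pip install -e CellSeg3D`; change it if pip reports another name.
print("napari-cellseg3d:", installed_version("napari-cellseg3d"))
```

A `None` result means pip does not see the package, in which case re-running the installation cell and restarting the runtime is the first thing to try.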

# **2. Inference**
---


## **2.1. Check for GPU access**
---

By default, this session is configured to use Python 3 and GPU acceleration. To verify or adjust these settings:

1. Open the **Runtime** menu and select **Change runtime type**.
2. Make sure the runtime type is set to **Python 3** (the language this notebook is written in).
3. Under **Hardware accelerator**, choose a **GPU** (Graphics Processing Unit).


In [None]:
#@markdown This cell checks whether a GPU is available to this session.

import torch
if not torch.cuda.is_available():
  print('You do not have GPU access.')
  print('Did you change your runtime type?')
  print('If the runtime setting is correct, Google did not allocate a GPU for your session.')
  print('Expect slow performance. To get a GPU, try reconnecting later.')

else:
  print('You have GPU access')
  !nvidia-smi


## **2.2 Run inference**
---

In [None]:
# @title Load demo image and inference configuration
#@markdown This cell loads a demo image and the inference configuration.
demo_image_path = "./CellSeg3D/examples/c5image.tif"
demo_image = imread(demo_image_path)
inference_config = cs3d.CONFIG
post_process_config = cs3d.PostProcessConfig()
# On Colab, select the "cupy" device for pyclesperanto_prototype (cle)
import pyclesperanto_prototype as cle
cle.select_device("cupy")

In [None]:
# @title Run inference on demo image
#@markdown This cell runs the inference on the demo image.
result = cs3d.inference_on_images(
    demo_image,
    config=inference_config,
)

In [None]:
# @title Post-process the result
# @markdown This cell post-processes the inference result: thresholding, instance segmentation, and statistics.
instance_segmentation, stats = cs3d.post_processing(
    result[0].semantic_segmentation,
    config=post_process_config,
)
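Conceptually, this step turns the continuous semantic prediction into individually labeled cells. The sketch below illustrates the idea with the simplest possible recipe, a fixed threshold followed by connected-component labeling with `scipy.ndimage.label`. It is not CellSeg3D's post-processing, whose thresholding and instance-segmentation settings come from `post_process_config`; the helper name `threshold_and_label` and the 0.5 threshold are assumptions.

```python
import numpy as np
from scipy import ndimage


def threshold_and_label(prediction, threshold=0.5):
    """Binarize a 3D probability map and label connected foreground regions.

    Illustrative only: threshold=0.5 and face (6-)connectivity, scipy's
    default structuring element in 3D, are assumptions, not CellSeg3D defaults.
    """
    foreground = np.asarray(prediction) > threshold
    labels, n_objects = ndimage.label(foreground)
    return labels, n_objects


# Toy (z, y, x) prediction with two separate blobs
pred = np.zeros((5, 10, 10), dtype=np.float32)
pred[1:3, 1:4, 1:4] = 0.9
pred[2:4, 6:9, 6:9] = 0.8
labels, n = threshold_and_label(pred)
print(n)  # 2
```

Plain connected components merge touching cells into a single label; separating them requires a dedicated instance-segmentation method such as watershed.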

In [None]:
# @title Display the result
#@markdown This cell displays the result of the inference and post-processing. Use the slider to navigate through the z-stack.
# @markdown Label colors use a fixed range (0 to the largest label), so each cell keeps the same color across z-slices.
import colorsys

import ipywidgets as widgets
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display

def random_label_cmap(n=2**16, h=(0, 1), l=(.4, 1), s=(.2, .8)):
    """Random colormap for label images, with label 0 (background) in black.

    Function taken from the StarDist repository:
    https://github.com/stardist/stardist/blob/c6c261081c6f9717fa9f5c47720ad2d5a9153224/stardist/plot/plot.py#L8
    """
    h, l, s = np.random.uniform(*h, n), np.random.uniform(*l, n), np.random.uniform(*s, n)
    cols = np.stack([colorsys.hls_to_rgb(_h, _l, _s) for _h, _l, _s in zip(h, l, s)], axis=0)
    cols[0] = 0  # background (label 0) is black
    return matplotlib.colors.ListedColormap(cols)

n_labels = int(instance_segmentation.max()) + 1
label_cmap = random_label_cmap(n=n_labels)

def update_plot(z):
    plt.figure(figsize=(15, 15))
    plt.subplot(1, 3, 1)
    plt.imshow(demo_image[z], cmap='gray')
    plt.subplot(1, 3, 2)
    plt.imshow(result[0].semantic_segmentation[z], cmap='turbo')
    plt.subplot(1, 3, 3)
    # Fixing vmin/vmax maps each label to the same color on every slice;
    # without them, imshow rescales to each slice's own label range.
    plt.imshow(instance_segmentation[z], cmap=label_cmap, vmin=0, vmax=n_labels - 1, interpolation='nearest')
    plt.show()

# Create a slider
z_slider = widgets.IntSlider(min=0, max=demo_image.shape[0]-1, step=1, value=demo_image.shape[0] // 2)

# Display the slider and update the plot when the slider is changed
widgets.interact(update_plot, z=z_slider)
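`random_label_cmap` draws fresh random colors every time the display cell runs, so a given cell changes color between runs. If you want the same colors each time, a seeded variant could look like the sketch below. The `seeded_label_cmap` name and its `seed` parameter are hypothetical additions, not part of StarDist or CellSeg3D; the function uses a local `numpy.random.default_rng` generator so the global NumPy random state is left untouched.

```python
import colorsys

import numpy as np
from matplotlib.colors import ListedColormap


def seeded_label_cmap(n, seed=0, h=(0, 1), l=(.4, 1), s=(.2, .8)):
    """Reproducible variant of random_label_cmap (hypothetical helper).

    The same n and seed always give the same colors; label 0 stays black.
    """
    rng = np.random.default_rng(seed)  # local generator, global state untouched
    hs, ls, ss = rng.uniform(*h, n), rng.uniform(*l, n), rng.uniform(*s, n)
    cols = np.array([colorsys.hls_to_rgb(_h, _l, _s) for _h, _l, _s in zip(hs, ls, ss)])
    cols[0] = 0  # background (label 0) is black
    return ListedColormap(cols)
```

To use it, replace the `random_label_cmap(...)` call in the display cell with `seeded_label_cmap(int(instance_segmentation.max()) + 1)`.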

In [None]:
# @title Display the statistics
# @markdown This cell displays the statistics of the post-processed result.
import pandas as pd
data = pd.DataFrame(stats.get_dict())
display(data)
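The table appears to hold one row per detected object (the 3D plot below reads per-object volumes and centroids from it). As a rough illustration of how such measurements can be derived from a label volume, the sketch below counts voxels with `np.bincount` and computes centers of mass with `scipy.ndimage.center_of_mass`. It is not CellSeg3D's implementation: the helper name `basic_label_stats` is hypothetical, volumes are in voxels, and how CellSeg3D's `Centroid x/y/z` columns map to array axes is an assumption.

```python
import numpy as np
from scipy import ndimage


def basic_label_stats(labels):
    """Voxel count and (z, y, x) centroid for every nonzero label.

    Illustrative sketch: volumes are voxel counts and centroids are in voxel
    coordinates; no physical voxel size is applied.
    """
    labels = np.asarray(labels)
    ids = np.unique(labels)
    ids = ids[ids != 0]  # label 0 is background
    volumes = np.bincount(labels.ravel())[ids]
    # Uniform weights, so each center of mass is the mean voxel coordinate
    centroids = ndimage.center_of_mass(np.ones(labels.shape), labels, ids)
    return {
        int(i): {"volume": int(v), "centroid_zyx": tuple(float(c) for c in cz)}
        for i, v, cz in zip(ids, volumes, centroids)
    }
```

Calling `basic_label_stats(instance_segmentation)` gives per-label values that can be compared against the table, keeping in mind the axis-naming assumption above.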

In [None]:
# @title Plot a 3D view with statistics
# @markdown This cell plots a 3D view of the cells, with the volume as the size of the points and the sphericity as the color.
import plotly.graph_objects as go
import numpy as np

def plotly_cells_stats(data):

    x = data["Centroid x"]
    y = data["Centroid y"]
    z = data["Centroid z"]

    fig = go.Figure(
        data=go.Scatter3d(
            x=np.floor(x),
            y=np.floor(y),
            z=np.floor(z),
            mode="markers",
            marker=dict(
                sizemode="diameter",
                sizeref=30,
                sizemin=20,
                size=data["Volume"],
                color=data["Sphericity (axes)"],
                colorscale="Turbo_r",
                colorbar_title="Sphericity",
                line_color="rgb(140, 140, 170)",
            ),
        )
    )

    fig.update_layout(
        height=600,
        width=600,
        title=f'Total number of cells: {int(data["Number objects"][0])}',
    )

    fig.show(renderer="colab")
    
plotly_cells_stats(data)