# **WNet: 3D Unsupervised Cell Segmentation**

---
*Disclaimer:*

This notebook, part of the [CellSeg3D project](https://github.com/AdaptiveMotorControlLab/CellSeg3d) under the [Mathis Lab of Adaptive Motor Control](https://www.mackenziemathislab.org/), is a work-in-progress resource for training the WNet model for unsupervised cell segmentation.

The foundation of this notebook owes much to the **[ZeroCostDL4Mic](https://github.com/HenriquesLab/ZeroCostDL4Mic)** project, a collaborative effort between the Jacquemet and Henriques laboratories, created by Daniel Krentzel. Except for the model provided here, all credit goes to their team.

# **1. Installing dependencies**
---

## **1.1. Installing CellSeg3D**
---

In [1]:
#@markdown ##Play to install WNet dependencies
!git clone https://github.com/AdaptiveMotorControlLab/CellSeg3d.git --branch cy/wnet-extras --single-branch ./CellSeg3D
!pip install -e CellSeg3D

fatal: destination path './CellSeg3D' already exists and is not an empty directory.
Obtaining file:///content/CellSeg3D
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: napari-cellseg3d
  Building editable for napari-cellseg3d (pyproject.toml) ... [?25l[?25hdone
  Created wheel for napari-cellseg3d: filename=napari_cellseg3d-0.0.3rc1-0.editable-py3-none-any.whl size=6209 sha256=307ee3cf2e41bdc6b51500507122f9cdac4161970a01fe63986c0743dbc738f7
  Stored in directory: /tmp/pip-ephem-wheel-cache-l3d04gq9/wheels/a1/e1/dc/cc9f89fc6f907d6bd38a2cbf3335706054a5435e97f664034d
Successfully built napari-cellseg3d
Installing collected packages: napari-cellseg3d
  Attempting uninstall: napari-cellseg3d
    Found existing installation: napari-cellseg3d 0.0.

## **1.2. Restart your runtime**
---
<font size = 4>


**<font size = 4> Running the cell below deliberately ends the session so that the newly installed package can be imported. Colab then restarts the runtime automatically; please ignore the crash message that appears.**

<img width="40%" alt ="" src="https://github.com/HenriquesLab/ZeroCostDL4Mic/raw/master/Wiki_files/session_crash.png"><figcaption>  </figcaption>

In [2]:
# @title
# Force a session restart so the editable install is picked up
exit(0)

## **1.3. Load key dependencies**
---

In [4]:
# @title
from pathlib import Path
from napari_cellseg3d.dev_scripts import colab_training as c
from napari_cellseg3d.config import WNetTrainingWorkerConfig, WandBConfig, WeightsInfo, PRETRAINED_WEIGHTS_DIR

DEBUG:napari_cellseg3d.utils:PRETRAINED WEIGHT DIR LOCATION : /content/CellSeg3D/napari_cellseg3d/code_models/models/pretrained


## **1.4. Initialize Weights & Biases integration (optional)**
---
If you wish to use Weights & Biases (WandB) to monitor and log your training session, run the cell below. When prompted, enter your WandB API key.

In [None]:
!pip install -q wandb
import wandb
wandb.login()

# **2. Set up the Colab session**
---


## **2.1. Check for GPU access**
---

By default, this session is configured to use Python 3 and GPU acceleration. To verify or adjust these settings:

1. <font size = 4>Open the **Runtime** menu and select **Change runtime type**.

2. <font size = 4>Make sure the runtime type is set to **Python 3** (the language this notebook is written in).

3. <font size = 4>Under **Hardware accelerator**, choose **GPU** (Graphics Processing Unit).


In [1]:
#@markdown ##Execute the cell below to verify if GPU access is available.

import torch
if not torch.cuda.is_available():
  print('You do not have GPU access.')
  print('Did you change your runtime?')
  print('If the runtime setting is correct then Google did not allocate a GPU for your session')
  print('Expect slow performance. To access GPU try reconnecting later')

else:
  print('You have GPU access')
  !nvidia-smi


You have GPU access
Thu Aug  3 14:22:50 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   57C    P8    10W /  70W |      3MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-------------------------------------------------------------------

## **2.2. Mount Google Drive**
---
<font size = 4>To use this notebook with your own data, first upload the data to Google Drive, organized as described in Section 3.1 (one multipage TIF file per training volume).

1. <font size = 4> **Run** the **cell** below and click on the provided link.

2. <font size = 4>Log in to your Google account and grant the necessary permissions by clicking 'Allow'.

3. <font size = 4>Copy the generated authorization code and paste it into the cell, then press 'Enter'. This grants Colab access to read and write data to your Google Drive.

4. <font size = 4> After completion, you can view your data in the notebook. Simply click the Files tab on the top left and select 'Refresh'.

In [6]:
#@markdown ##Play the cell to connect your Google Drive to Colab

#@markdown * Click on the URL.

#@markdown * Sign in to your Google account.

#@markdown * Copy the authorization code.

#@markdown * Enter the authorization code.

#@markdown * Click on "Files" in the left sidebar and refresh it. Your Google Drive folder should now be available there as "gdrive".

# mount user's Google Drive to Google Colab.
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


**<font size = 4> If you cannot see your files, reactivate your session by connecting to your hosted runtime.**


<img width="40%" alt ="Connect to a hosted runtime." src="https://github.com/HenriquesLab/ZeroCostDL4Mic/raw/master/Wiki_files/connect_to_hosted.png"><figcaption> Connect to a hosted runtime. </figcaption>

In [7]:
# @title
# import wandb
# wandb.login()

# **3. Select your parameters and paths**
---

## **3.1. Choosing parameters**

---

### **Paths to the training data and model**

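* <font size = 4>**`training_source`** is the path to the folder containing the training data. Each training volume must be a single multipage TIF file.

* <font size = 4>**`model_path`** specifies the directory where the model checkpoints will be saved.

* <font size = 4>**`eval_source`** and **`eval_target`** (optional) are the paths to the folders containing the evaluation volumes and their labels, respectively. They are only used if `do_validation` is checked.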

<font size = 4>**Tip:** To easily copy paths, navigate to the 'Files' tab, right-click on a folder or file, and choose 'Copy path'.
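Before training, it can help to confirm that `training_source` actually contains TIF volumes. The sketch below is a hypothetical helper (`list_tif_volumes` is not part of CellSeg3D); it only inspects file extensions and does not check that each file is a valid multipage TIF.

```python
from pathlib import Path


def list_tif_volumes(folder):
    """Return the sorted .tif/.tiff files directly inside `folder`.

    Hypothetical helper: checks file extensions only, not file contents.
    """
    folder = Path(folder)
    if not folder.is_dir():
        raise FileNotFoundError(f"Training folder not found: {folder}")
    volumes = sorted(
        p for p in folder.iterdir()
        if p.is_file() and p.suffix.lower() in {".tif", ".tiff"}
    )
    if not volumes:
        raise ValueError(f"No .tif/.tiff files found in {folder}")
    return volumes


# Usage in this notebook, after mounting Drive and setting training_source:
# for vol in list_tif_volumes(training_source):
#     print(vol.name)
```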

### **Training parameters**

* <font size = 4>**`number_of_epochs`** is the number of times the entire training data will be seen by the model. Default: 50

* <font size = 4>**`batch_size`** is the number of images bundled together at each training step. Default: 4

* <font size = 4>**`learning_rate`** is the step size used when updating the model's weights. Try decreasing it if the NCuts loss is unstable. Default: 2e-5

* <font size = 4>**`num_classes`** is the number of brightness clusters to segment the image into. Try raising it to 3 if your cells are surrounded by artifacts or "halos" of significantly different brightness. Default: 2

* <font size = 4>**`weight_decay`** is a regularization parameter used to prevent overfitting. Default: 0.01

* <font size = 4>**`validation_frequency`** is how often, in epochs, the provided evaluation data is used to estimate the model's performance. Default: 2

* <font size = 4>**`intensity_sigma`** is the standard deviation of the feature similarity term. Default: 1

* <font size = 4>**`spatial_sigma`** is the standard deviation of the spatial proximity term. Default: 4

* <font size = 4>**`ncuts_radius`** is the radius for the NCuts loss computation, in pixels. Default: 2

* <font size = 4>**`rec_loss`** is the loss used for the decoder: Mean Squared Error (MSE) or Binary Cross-Entropy (BCE). Default: MSE

* <font size = 4>**`n_cuts_weight`** is the weight of the NCuts loss in the weighted sum for the backward pass. Default: 0.5
* <font size = 4>**`rec_loss_weight`** is the weight of the reconstruction loss. Default: 0.005
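To make the roles of `intensity_sigma`, `spatial_sigma`, `ncuts_radius`, `n_cuts_weight` and `rec_loss_weight` more concrete, here is a minimal sketch of the voxel affinity used by the soft normalized-cuts loss in the original W-Net paper (Xia & Kulis), which builds on Shi & Malik's normalized cuts. This is an illustration under that assumption, not CellSeg3D's implementation: `ncuts_affinity` and `total_loss` are hypothetical names, and the plugin may differ in details such as normalization or the radius boundary. The weighted sum mirrors the formula printed in the training log (`0.5*NCuts + 0.005*Reconstruction`).

```python
import math


def ncuts_affinity(intensity_i, intensity_j, pos_i, pos_j,
                   intensity_sigma=1.0, spatial_sigma=4.0, radius=2):
    """Affinity w_ij between two voxels (W-Net / normalized-cuts formulation).

    Assumption: follows the published formula; CellSeg3D's implementation
    may differ in details.
    """
    # Squared Euclidean distance between voxel coordinates, e.g. (z, y, x).
    sq_dist = sum((a - b) ** 2 for a, b in zip(pos_i, pos_j))
    if math.sqrt(sq_dist) >= radius:
        return 0.0  # voxels at or beyond `radius` are not connected
    # Feature similarity: close to 1 when intensities are similar.
    intensity_term = math.exp(-((intensity_i - intensity_j) ** 2) / intensity_sigma ** 2)
    # Spatial proximity: close to 1 when voxels are near each other.
    spatial_term = math.exp(-sq_dist / spatial_sigma ** 2)
    return intensity_term * spatial_term


def total_loss(ncuts_loss, reconstruction_loss, n_cuts_weight=0.5, rec_loss_weight=0.005):
    """Weighted sum of both losses, as printed in the training log."""
    return n_cuts_weight * ncuts_loss + rec_loss_weight * reconstruction_loss


# Same intensity, adjacent voxels: only the spatial term contributes.
print(ncuts_affinity(0.5, 0.5, (0, 0, 0), (0, 0, 1)))  # exp(-1/16), about 0.9394
# Beyond the radius the affinity is zero.
print(ncuts_affinity(0.5, 0.5, (0, 0, 0), (0, 0, 3)))  # 0.0
# 0.5 * 0.8 + 0.005 * 10.0, about 0.45
print(total_loss(0.8, 10.0))
```

In this formulation, raising `intensity_sigma` makes voxels of different brightness count as more similar, while raising `spatial_sigma` or `ncuts_radius` links voxels over larger distances; a larger radius also means more voxel pairs, and therefore a more expensive loss.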


In [5]:
#@markdown ###Path to the training data:
training_source = "./gdrive/MyDrive/CELLSEG_BENCHMARK/DATA/WNET/VIP_full" #@param {type:"string"}
#@markdown ###Path to the model folder:
model_path = "./gdrive/MyDrive/CELLSEG_BENCHMARK/WNET_TRAINING_RESULTS" #@param {type:"string"}
#@markdown ---
#@markdown ###Perform validation on a test dataset
do_validation = True #@param {type:"boolean"}
#@markdown ###Path to evaluation data (optional, use if checked above):
eval_source = "./gdrive/MyDrive/CELLSEG_BENCHMARK/DATA/WNET/eval/vol/" #@param {type:"string"}
eval_target = "./gdrive/MyDrive/CELLSEG_BENCHMARK/DATA/WNET/eval/lab/" #@param {type:"string"}
#@markdown ---
#@markdown ###Training parameters
number_of_epochs = 50 #@param {type:"number"}
#@markdown ###Default advanced parameters
use_default_advanced_parameters = False #@param {type:"boolean"}
#@markdown <font size = 4>If not, please change:

#@markdown <font size = 3>Training parameters:
batch_size =  4 #@param {type:"number"}
learning_rate = 2e-5 #@param {type:"number"}
num_classes = 2 #@param {type:"number"}
weight_decay = 0.01 #@param {type:"number"}
#@markdown <font size = 3>Validation parameters:
validation_frequency = 2 #@param {type:"number"}
#@markdown <font size = 3>SoftNCuts parameters:
intensity_sigma = 1.0 #@param {type:"number"}
spatial_sigma = 4.0 #@param {type:"number"}
ncuts_radius = 2 #@param {type:"number"}
#@markdown <font size = 3>Reconstruction loss:
rec_loss = "MSE" #@param["MSE", "BCE"]
#@markdown <font size = 3>Weighted sum of losses:
n_cuts_weight = 0.5 #@param {type:"number"}
rec_loss_weight = 0.005 #@param {type:"number"}

# **4. Train the network**
---

<font size = 4>Important Reminder: Google Colab imposes a maximum session time to prevent extended GPU usage, such as for data mining. Ensure your training duration stays under 12 hours. If your training is projected to exceed this limit, consider reducing the `number_of_epochs`.
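To pick a `number_of_epochs` that fits in the session, one option is to time a single epoch and extrapolate. The helper below is a hypothetical sketch: the 12-hour budget comes from the reminder above, while the one-hour safety margin (for setup, validation and checkpoint saving) is an assumption you should adjust to your own runs.

```python
import math


def max_epochs_within_budget(seconds_per_epoch, budget_hours=12.0, margin_hours=1.0):
    """Largest whole number of epochs that fits in the Colab session budget.

    `margin_hours` is an assumed buffer for setup, validation and checkpointing.
    """
    if seconds_per_epoch <= 0:
        raise ValueError("seconds_per_epoch must be positive")
    usable_seconds = max(0.0, (budget_hours - margin_hours) * 3600)
    return math.floor(usable_seconds / seconds_per_epoch)


# Example: if one epoch takes about 10 minutes (600 s),
# (12 - 1) h * 3600 s/h / 600 s = 66 epochs fit, so 50 epochs would be safe.
print(max_epochs_within_budget(600))  # → 66
```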

## **4.1. Initialize the config**
---

In [7]:
# @title
train_data_folder = Path(training_source)
results_path = Path(model_path)
results_path.mkdir(parents=True, exist_ok=True)  # also create missing parent folders
eval_image_folder = Path(eval_source)
eval_label_folder = Path(eval_target)

eval_dict = c.create_eval_dataset_dict(
        eval_image_folder,
        eval_label_folder,
    ) if do_validation else None

try:
  import wandb
  WANDB_INSTALLED = True
except ImportError:
  WANDB_INSTALLED = False


train_config = WNetTrainingWorkerConfig(
    device="cuda:0",
    max_epochs=number_of_epochs,
    learning_rate=2e-5,
    validation_interval=2,
    batch_size=4,
    num_workers=2,
    weights_info=WeightsInfo(),
    results_path_folder=str(results_path),
    train_data_dict=c.create_dataset_dict_no_labs(train_data_folder),
    eval_volume_dict=eval_dict,
) if use_default_advanced_parameters else WNetTrainingWorkerConfig(
    device="cuda:0",
    max_epochs=number_of_epochs,
    learning_rate=learning_rate,
    validation_interval=validation_frequency,
    batch_size=batch_size,
    num_workers=2,
    weights_info=WeightsInfo(),
    results_path_folder=str(results_path),
    train_data_dict=c.create_dataset_dict_no_labs(train_data_folder),
    eval_volume_dict=eval_dict,
    # advanced
    num_classes=num_classes,
    weight_decay=weight_decay,
    intensity_sigma=intensity_sigma,
    spatial_sigma=spatial_sigma,
    radius=ncuts_radius,
    reconstruction_loss=rec_loss,
    n_cuts_weight=n_cuts_weight,
    rec_loss_weight=rec_loss_weight,
)
wandb_config = WandBConfig(
    mode="disabled" if not WANDB_INSTALLED else "online",
    save_model_artifact=False,
)

INFO:napari_cellseg3d.utils:Images :

INFO:napari_cellseg3d.utils:c3_image_cropped_eval.tif
INFO:napari_cellseg3d.utils:**********
INFO:napari_cellseg3d.utils:Labels :

INFO:napari_cellseg3d.utils:c3_labels_cropped_eval.tif
INFO:napari_cellseg3d.utils:Images :
INFO:napari_cellseg3d.utils:c1_images_cropped_10
INFO:napari_cellseg3d.utils:c1_images_cropped_11
INFO:napari_cellseg3d.utils:c1_images_cropped_12
INFO:napari_cellseg3d.utils:c1_images_cropped_13
INFO:napari_cellseg3d.utils:c1_images_cropped_14
INFO:napari_cellseg3d.utils:c1_images_cropped_15
INFO:napari_cellseg3d.utils:c1_images_cropped_16
INFO:napari_cellseg3d.utils:c1_images_cropped_17
INFO:napari_cellseg3d.utils:c1_images_cropped_18
INFO:napari_cellseg3d.utils:c1_images_cropped_19
INFO:napari_cellseg3d.utils:c1_images_cropped_20
INFO:napari_cellseg3d.utils:c1_images_cropped_21
INFO:napari_cellseg3d.utils:c1_images_cropped_22
INFO:napari_cellseg3d.utils:c1_images_cropped_23
INFO:napari_cellseg3d.utils:c1_images_cropped_3
INFO:

## **4.2. Start training**
---

In [None]:
# @title
worker = c.get_colab_worker(worker_config=train_config, wandb_config=wandb_config)
for epoch_loss in worker.train():
  continue

DEBUG:napari_cellseg3d.utils:wandb config : {'device': 'cuda:0', 'max_epochs': 50, 'learning_rate': 2e-05, 'validation_interval': 2, 'batch_size': 4, 'deterministic_config': DeterministicConfig(enabled=True, seed=34936339), 'scheduler_factor': 0.5, 'scheduler_patience': 10, 'weights_info': WeightsInfo(path='/content/CellSeg3D/napari_cellseg3d/code_models/models/pretrained', custom=False, use_pretrained=False), 'results_path_folder': 'gdrive/MyDrive/CELLSEG_BENCHMARK/WNET_TRAINING_RESULTS', 'sampling': False, 'num_samples': 2, 'sample_size': None, 'do_augmentation': True, 'num_workers': 2, 'train_data_dict': [{'image': 'gdrive/MyDrive/CELLSEG_BENCHMARK/DATA/WNET/VIP_full/c1_images_cropped_10.tif'}, {'image': 'gdrive/MyDrive/CELLSEG_BENCHMARK/DATA/WNET/VIP_full/c1_images_cropped_11.tif'}, {'image': 'gdrive/MyDrive/CELLSEG_BENCHMARK/DATA/WNET/VIP_full/c1_images_cropped_12.tif'}, {'image': 'gdrive/MyDrive/CELLSEG_BENCHMARK/DATA/WNET/VIP_full/c1_images_cropped_13.tif'}, {'image': 'gdrive/My

INFO:napari_cellseg3d.utils:********************
INFO:napari_cellseg3d.utils:-- Parameters --
INFO:napari_cellseg3d.utils:Device: cuda:0
INFO:napari_cellseg3d.utils:Batch size: 4
INFO:napari_cellseg3d.utils:Epochs: 50
INFO:napari_cellseg3d.utils:Learning rate: 2e-05
INFO:napari_cellseg3d.utils:Validation interval: 2
INFO:napari_cellseg3d.utils:Using data augmentation
INFO:napari_cellseg3d.utils:-- Model --
INFO:napari_cellseg3d.utils:Using 2 classes
INFO:napari_cellseg3d.utils:Weight decay: 0.01
INFO:napari_cellseg3d.utils:* NCuts : 
INFO:napari_cellseg3d.utils:- Intensity sigma 1.0
INFO:napari_cellseg3d.utils:- Spatial sigma 4.0
INFO:napari_cellseg3d.utils:- Radius : 2
INFO:napari_cellseg3d.utils:* Reconstruction loss : MSE
INFO:napari_cellseg3d.utils:Weighted sum : 0.5*NCuts + 0.005*Reconstruction
INFO:napari_cellseg3d.utils:-- Data --
INFO:napari_cellseg3d.utils:Training data :

INFO:napari_cellseg3d.utils:gdrive/MyDrive/CELLSEG_BENCHMARK/DATA/WNET/VIP_full/c1_images_cropped_10.tif
