<a href="https://colab.research.google.com/github/AdaptiveMotorControlLab/CellSeg3d/blob/main/notebooks/Colab_WNet3D_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **WNet3D: self-supervised 3D cell segmentation**

---

This notebook is part of the [CellSeg3D project](https://github.com/AdaptiveMotorControlLab/CellSeg3d) in the [Mathis Lab of Adaptive Intelligence](https://www.mackenziemathislab.org/).

- 💜 The foundation of this notebook owes much to the **[ZeroCostDL4Mic](https://github.com/HenriquesLab/ZeroCostDL4Mic)** project and to the **[DeepLabCut](https://github.com/DeepLabCut/DeepLabCut)** team for bringing Colab into scientific open software.

# **1. Installing dependencies**
---

In [None]:
# #@markdown ##Play to install CellSeg3D and WNet3D dependencies:
# !pip install -q napari-cellseg3d
# print("Dependencies installed")

## **1.2. Load key dependencies**
---

In [None]:
# @title
from pathlib import Path
from napari_cellseg3d.dev_scripts import colab_training as c
from napari_cellseg3d.config import WNetTrainingWorkerConfig, WandBConfig, WeightsInfo, PRETRAINED_WEIGHTS_DIR

## Optional - *1.3 Initialize Weights & Biases integration*
---
If you want to use Weights & Biases (WandB) to monitor and log your training session, uncomment and run the cell below.
When `wandb.login()` prompts you, paste your API key.

In [None]:
# !pip install -q wandb
# import wandb
# wandb.login()

# **2. Initialize the Colab session**
---



## **2.1. Check for GPU access**
---

By default, this session is configured to use Python 3 and GPU acceleration. To verify or adjust these settings:

<font size = 4>Open the **Runtime** menu and select **Change runtime type**.

<font size = 4>For Runtime type, ensure it's set to Python 3 (the programming language this program is written in).

<font size = 4>Under Accelerator, choose GPU (Graphics Processing Unit).


In [None]:
#@markdown ##Execute the cell below to verify if GPU access is available.

import torch
if not torch.cuda.is_available():
  print('You do not have GPU access.')
  print('Did you change your runtime?')
  print('If the runtime setting is correct then Google did not allocate a GPU for your session')
  print('Expect slow performance. To access GPU try reconnecting later')

else:
  print('You have GPU access')
  !nvidia-smi


## **2.2. Mount Google Drive**
---
<font size = 4>To use this notebook with your own data, save the data on Google Drive following the path requirements described in Section 3.1.

1. <font size = 4> **Run** the **cell** below and click on the provided link.

2. <font size = 4>Log in to your Google account and grant the necessary permissions by clicking 'Allow'.

3. <font size = 4>Copy the generated authorization code and paste it into the cell, then press 'Enter'. This grants Colab access to read and write data to your Google Drive.

4. <font size = 4> After completion, you can view your data in the notebook. Simply click the Files tab on the top left and select 'Refresh'.

In [None]:
# # mount user's Google Drive to Google Colab.
# from google.colab import drive
# drive.mount('/content/gdrive')

**<font size = 4> If you cannot see your files, reactivate your session by connecting to your hosted runtime.**


<img width="40%" alt ="Connect to a hosted runtime." src="https://github.com/HenriquesLab/ZeroCostDL4Mic/raw/master/Wiki_files/connect_to_hosted.png"><figcaption> Connect to a hosted runtime. </figcaption>

# **3. Select your parameters and paths**
---

## **3.1. Choosing parameters**

---

### **Paths to the training data and model**

* <font size = 4>**`training_source`** specifies the folder containing the training data. Each volume must be a single multipage TIF file.

* <font size = 4>**`model_save_path`** specifies the directory where the model checkpoints will be saved.

<font size = 4>**Tip:** To easily copy paths, navigate to the 'Files' tab, right-click on a folder or file, and choose 'Copy path'.
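Before training, it can help to confirm that `training_source` actually contains TIF volumes. The helper below is a hypothetical sketch (`list_training_volumes` is not part of CellSeg3D); it only checks file names and extensions, not whether each file really is a 3D multipage stack.

```python
from pathlib import Path


def list_training_volumes(folder: str) -> list[Path]:
    """Return the sorted .tif/.tiff files directly inside `folder`.

    Hypothetical helper for illustration: it checks names only and
    does not open the files to verify they are 3D stacks.
    """
    root = Path(folder)
    if not root.is_dir():
        raise NotADirectoryError(f"training_source is not a folder: {root}")
    volumes = sorted(
        p for p in root.iterdir()
        if p.is_file() and p.suffix.lower() in {".tif", ".tiff"}
    )
    if not volumes:
        raise FileNotFoundError(f"No .tif/.tiff volumes found in {root}")
    return volumes
```

For example, `list_training_volumes(training_source)` fails early with a clear message if the path points at the wrong folder, instead of failing later inside dataset loading.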

### **Training parameters**

* <font size = 4>**`number_of_epochs`** is the number of times the entire training data will be seen by the model. Default: 50

* <font size = 4>**`batch_size`** is the number of images bundled together at each training step. Default: 4

* <font size = 4>**`learning_rate`** is the step size used when updating the model's weights. Try decreasing it if the NCuts loss is unstable. Default: 2e-5

* <font size = 4>**`num_classes`** is the number of brightness clusters to segment the image into. Try raising it to 3 if you have artifacts or "halos" around your cells with significantly different brightness. Default: 2

* <font size = 4>**`weight_decay`** is a regularization parameter used to prevent overfitting. Default: 0.01

* <font size = 4>**`validation_frequency`** is how often, in epochs, the provided evaluation data is used to estimate the model's performance. Default: 2

* <font size = 4>**`intensity_sigma`** is the standard deviation of the feature similarity term. Default: 1

* <font size = 4>**`spatial_sigma`** is the standard deviation of the spatial proximity term. Default: 4

* <font size = 4>**`ncuts_radius`** is the radius for the NCuts loss computation, in pixels. Default: 2

* <font size = 4>**`rec_loss`** is the loss used for the decoder: Mean Squared Error (MSE) or Binary Cross Entropy (BCE). Default: MSE

* <font size = 4>**`n_cuts_weight`** is the weight of the NCuts loss in the weighted sum for the backward pass. Default: 0.5

* <font size = 4>**`rec_loss_weight`** is the weight of the reconstruction loss in the same weighted sum. Default: 0.005
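The three SoftNCuts parameters control how strongly two voxels are linked in the normalized-cuts graph. As a sketch, assuming the affinity follows the original W-Net / Shi & Malik form $w_{ij} = \exp(-\lVert F_i - F_j\rVert^2/\sigma_I^2)\,\exp(-\lVert X_i - X_j\rVert^2/\sigma_X^2)$ when $\lVert X_i - X_j\rVert < r$ (and $0$ otherwise), the per-pair weight and the weighted loss sum look like this. CellSeg3D's own SoftNCuts computes this over whole tensors and may normalize differently; `ncuts_affinity` and `weighted_total_loss` are illustrative names, not library functions.

```python
import math


def ncuts_affinity(f_i: float, f_j: float, x_i, x_j,
                   intensity_sigma: float = 1.0,
                   spatial_sigma: float = 4.0,
                   radius: float = 2.0) -> float:
    """Affinity between two voxels (assumed W-Net / Shi & Malik form).

    f_i, f_j: voxel intensities; x_i, x_j: voxel coordinates, e.g. (z, y, x).
    """
    spatial_dist = math.dist(x_i, x_j)
    if spatial_dist >= radius:
        return 0.0  # voxels at or beyond the radius are not connected
    # Similar intensities -> term close to 1; controlled by intensity_sigma
    intensity_term = math.exp(-((f_i - f_j) ** 2) / intensity_sigma ** 2)
    # Nearby voxels -> term close to 1; controlled by spatial_sigma
    spatial_term = math.exp(-(spatial_dist ** 2) / spatial_sigma ** 2)
    return intensity_term * spatial_term


def weighted_total_loss(ncuts_loss: float, rec_loss: float,
                        n_cuts_weight: float = 0.5,
                        rec_loss_weight: float = 0.005) -> float:
    """Assumed weighted sum of the encoder (NCuts) and decoder (reconstruction) losses."""
    return n_cuts_weight * ncuts_loss + rec_loss_weight * rec_loss
```

Raising `intensity_sigma` makes voxels with different brightness look more alike, raising `spatial_sigma` weakens the distance penalty, and raising `ncuts_radius` connects more neighbors (at a higher compute cost).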


In [None]:
#@markdown ###Path to the training data:
training_source = "/midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/train" #@param {type:"string"}
#@markdown ###Path to save the weights (make sure to have enough space in your drive):
model_save_path = "/midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/checkpoints" #@param {type:"string"}
#@markdown ---
#@markdown ###Perform validation on a test dataset (optional):
do_validation = False #@param {type:"boolean"}
#@markdown ###Path to evaluation data (optional, use if checked above):
eval_source = "/midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/val/vol" #@param {type:"string"}
eval_target = "/midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/val/lab" #@param {type:"string"}
#@markdown ---
#@markdown ###Training parameters
number_of_epochs = 50 #@param {type:"number"}
#@markdown ###Default advanced parameters
use_default_advanced_parameters = False #@param {type:"boolean"}
#@markdown <font size = 4>If not, please change:

#@markdown <font size = 3>Training parameters:
batch_size =  4 #@param {type:"number"}
learning_rate = 2e-5 #@param {type:"number"}
num_classes = 2 #@param {type:"number"}
weight_decay = 0.01 #@param {type:"number"}
#@markdown <font size = 3>Validation parameters:
validation_frequency = 2 #@param {type:"number"}
#@markdown <font size = 3>SoftNCuts parameters:
intensity_sigma = 1.0 #@param {type:"number"}
spatial_sigma = 4.0 #@param {type:"number"}
ncuts_radius = 2 #@param {type:"number"}
#@markdown <font size = 3>Reconstruction loss:
rec_loss = "MSE" #@param["MSE", "BCE"]
#@markdown <font size = 3>Weighted sum of losses:
n_cuts_weight = 0.5 #@param {type:"number"}
rec_loss_weight = 0.005 #@param {type:"number"}

# **4. Train the network**
---

<font size = 4>Important Reminder: Google Colab imposes a maximum session time to prevent extended GPU usage, such as for data mining. Ensure your training duration stays under 12 hours. If your training is projected to exceed this limit, consider reducing the `number_of_epochs`.
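To pick `number_of_epochs` under that limit, time the first epoch or two and scale. A minimal sketch of the arithmetic; `max_epochs_within_budget` is a hypothetical helper, and the 0.8 safety margin is an arbitrary assumption meant to leave room for setup, validation and checkpoint saving.

```python
import math


def max_epochs_within_budget(seconds_per_epoch: float,
                             budget_hours: float = 12.0,
                             safety_margin: float = 0.8) -> int:
    """Largest whole number of epochs that fits in a fraction of the time budget."""
    if seconds_per_epoch <= 0:
        raise ValueError("seconds_per_epoch must be positive")
    # Only use part of the budget; the rest is reserved for overhead
    usable_seconds = budget_hours * 3600 * safety_margin
    return math.floor(usable_seconds / seconds_per_epoch)
```

For example, at about 10 minutes per epoch, `max_epochs_within_budget(600)` gives 57 epochs.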

## **4.1. Initialize the config**
---

In [None]:
# @title
train_data_folder = Path(training_source)
results_path = Path(model_save_path)
results_path.mkdir(parents=True, exist_ok=True)
eval_image_folder = Path(eval_source)
eval_label_folder = Path(eval_target)

eval_dict = c.create_eval_dataset_dict(
        eval_image_folder,
        eval_label_folder,
    ) if do_validation else None

try:
  import wandb
  WANDB_INSTALLED = True
except ImportError:
  WANDB_INSTALLED = False


train_config = WNetTrainingWorkerConfig(
    device="cuda:0",
    max_epochs=number_of_epochs,
    learning_rate=2e-5,
    validation_interval=2,
    batch_size=4,
    num_workers=2,
    weights_info=WeightsInfo(),
    results_path_folder=str(results_path),
    train_data_dict=c.create_dataset_dict_no_labs(train_data_folder),
    eval_volume_dict=eval_dict,
) if use_default_advanced_parameters else WNetTrainingWorkerConfig(
    device="cuda:0",
    max_epochs=number_of_epochs,
    learning_rate=learning_rate,
    validation_interval=validation_frequency,
    batch_size=batch_size,
    num_workers=2,
    weights_info=WeightsInfo(),
    results_path_folder=str(results_path),
    train_data_dict=c.create_dataset_dict_no_labs(train_data_folder),
    eval_volume_dict=eval_dict,
    # advanced
    num_classes=num_classes,
    weight_decay=weight_decay,
    intensity_sigma=intensity_sigma,
    spatial_sigma=spatial_sigma,
    radius=ncuts_radius,
    reconstruction_loss=rec_loss,
    n_cuts_weight=n_cuts_weight,
    rec_loss_weight=rec_loss_weight,
)
wandb_config = WandBConfig(
    mode="disabled" if not WANDB_INSTALLED else "online",
    save_model_artifact=False,
)

## **4.2. Start training**
---

In [None]:
# @title
worker = c.get_colab_worker(worker_config=train_config, wandb_config=wandb_config)
for epoch_loss in worker.train():
  continue

In [None]:
# Once training is complete, the model weights are saved as .pth files in `model_save_path`

# Adapt to Selma data

In [1]:
from pathlib import Path
import numpy as np
import nibabel as nib
import tifffile as tiff
import csv
import random

# ----------------------------
# Paths
# ----------------------------
nii_root = Path("/midtier/paetzollab/scratch/ads4015/data_selma3d/selma3d_finetune_patches/cell_nucleus_patches")
base_out = Path("/midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning")
tif_root = base_out / "tif_data"

tif_train_vol = tif_root / "train" / "vol"
tif_train_lab = tif_root / "train" / "lab"
tif_val_vol   = tif_root / "val"   / "vol"
tif_val_lab   = tif_root / "val"   / "lab"
tif_test_vol  = tif_root / "test"  / "vol"
tif_test_lab  = tif_root / "test"  / "lab"

for p in [tif_train_vol, tif_train_lab, tif_val_vol, tif_val_lab, tif_test_vol, tif_test_lab]:
    p.mkdir(parents=True, exist_ok=True)

# ----------------------------
# Collect all image files
# ----------------------------
all_imgs = sorted(
    f for f in nii_root.glob("*.nii.gz")
    if f.name.endswith("_ch0.nii.gz") and "_label" not in f.name
)

print(f"Found {len(all_imgs)} image patches")

if len(all_imgs) < 4:
    raise RuntimeError("Need at least 4 patches for 2 test + train/val split.")

# ----------------------------
# Deterministic split
# ----------------------------
seed = 123
random.Random(seed).shuffle(all_imgs)

n_total = len(all_imgs)
n_test = 2
n_trainval = n_total - n_test
n_val = max(1, int(round(0.2 * n_trainval)))
n_train = n_trainval - n_val

test_imgs  = all_imgs[:n_test]
val_imgs   = all_imgs[n_test:n_test + n_val]
train_imgs = all_imgs[n_test + n_val:]

print(f"Train: {len(train_imgs)}, Val: {len(val_imgs)}, Test: {len(test_imgs)}")

# ----------------------------
# Helper: NIfTI -> np array
# ----------------------------
def load_nifti(path: Path):
    nii = nib.load(str(path))
    arr = nii.get_fdata()
    arr = np.squeeze(arr)  # drop singleton channel if present

    # NIfTI is typically (X, Y, Z); for TIFF stacks we want (Z, Y, X)
    if arr.ndim != 3:
        raise ValueError(f"Expected 3D volume for {path}, got shape {arr.shape}")

    # reorder axes: (X, Y, Z) -> (Z, Y, X)
    arr = np.transpose(arr, (2, 1, 0))

    return arr


def img_to_label_path(img_path: Path) -> Path:
    # patch_000_vol003_ch0.nii.gz -> patch_000_vol003_ch0_label.nii.gz
    return img_path.with_name(img_path.name.replace(".nii.gz", "_label.nii.gz"))

# ----------------------------
# Convert and save
# ----------------------------
rows = []  # to record splits

def process_split(split_name, img_paths, vol_dir, lab_dir):
    for img_path in img_paths:
        lab_path = img_to_label_path(img_path)
        if not lab_path.exists():
            raise FileNotFoundError(f"Missing label for {img_path}: {lab_path}")

        img = load_nifti(img_path)
        lab = load_nifti(lab_path)

        # cast image to float32 so ITK can read it (no 64-bit samples)
        img = img.astype(np.float32)

        # cast label to uint8 (0/1 or small integers)
        lab = lab.astype(np.uint8)

        # properly strip ".nii.gz"
        base_name = img_path.name.replace(".nii.gz", "")  # -> "patch_004_vol004_ch0"

        out_img_path = vol_dir / f"{base_name}.tif"
        out_lab_path = lab_dir / f"{base_name}_label.tif"

        tiff.imwrite(out_img_path, img)
        tiff.imwrite(out_lab_path, lab)


        rows.append({
            "split": split_name,
            "nii_image": str(img_path),
            "nii_label": str(lab_path),
            "tif_image": str(out_img_path),
            "tif_label": str(out_lab_path),
        })

process_split("train", train_imgs, tif_train_vol, tif_train_lab)
process_split("val",   val_imgs,   tif_val_vol,   tif_val_lab)
process_split("test",  test_imgs,  tif_test_vol,  tif_test_lab)

# ----------------------------
# Save split info
# ----------------------------
split_csv = base_out / "splits_cell_nucleus.csv"
with open(split_csv, "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["split", "nii_image", "nii_label", "tif_image", "tif_label"])
    writer.writeheader()
    writer.writerows(rows)

print(f"Saved split info to {split_csv}")
print("Done creating TIF dataset.")


Found 25 image patches
Train: 18, Val: 5, Test: 2
Saved split info to /midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/splits_cell_nucleus.csv
Done creating TIF dataset.
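The split sizes in the output above follow directly from the split logic: with 25 patches, 2 are held out for testing, 20% of the remaining 23 (4.6, rounded to 5) go to validation, and 18 are left for training. The same logic, refactored into two hypothetical helper functions so it can be checked in isolation:

```python
import random


def split_counts(n_total: int, n_test: int = 2,
                 val_fraction: float = 0.2) -> tuple[int, int, int]:
    """Return (n_train, n_val, n_test) using the same rules as the cell above."""
    n_trainval = n_total - n_test
    n_val = max(1, int(round(val_fraction * n_trainval)))  # at least one val patch
    n_train = n_trainval - n_val
    return n_train, n_val, n_test


def deterministic_split(items, seed: int = 123, n_test: int = 2,
                        val_fraction: float = 0.2):
    """Shuffle with a fixed seed, then slice into (train, val, test) lists."""
    items = list(items)
    random.Random(seed).shuffle(items)  # same seed -> same order every run
    _, n_val, _ = split_counts(len(items), n_test, val_fraction)
    test = items[:n_test]
    val = items[n_test:n_test + n_val]
    train = items[n_test + n_val:]
    return train, val, test
```

Because the shuffle uses its own `random.Random(seed)` instance, rerunning the conversion cell reproduces the same split without touching the global random state.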


In [2]:
from pathlib import Path
from napari_cellseg3d.dev_scripts import colab_training as c
from napari_cellseg3d.config import WNetTrainingWorkerConfig, WandBConfig, WeightsInfo

# ------------------------------------------------------------------
# Use the TIF dataset we just created
# ------------------------------------------------------------------
# training_source is the folder with multipage TIF volumes (no labels)
training_source = "/midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/tif_data/train/vol"

# model_save_path: where to store training results and checkpoints
model_save_path = "/midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/checkpoints"

# We'll use the val split for in-training validation
do_validation = True
eval_source = "/midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/tif_data/val/vol"
eval_target = "/midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/tif_data/val/lab"

# ------------------------------------------------------------------
# Training hyperparameters
# ------------------------------------------------------------------
number_of_epochs = 50

use_default_advanced_parameters = False  # we'll set explicitly

batch_size = 4
learning_rate = 2e-5
num_classes = 2            # nuclei vs background
weight_decay = 0.01

validation_frequency = 2   # validate every 2 epochs

# SoftNCuts parameters
intensity_sigma = 1.0
spatial_sigma = 4.0
ncuts_radius = 2

# Reconstruction loss
rec_loss = "MSE"           # or "BCE"

# Loss mixing
n_cuts_weight = 0.5
rec_loss_weight = 0.005


In [3]:
from pathlib import Path

train_data_folder = Path(training_source)
results_path = Path(model_save_path)
results_path.mkdir(parents=True, exist_ok=True)

eval_image_folder = Path(eval_source)
eval_label_folder = Path(eval_target)

eval_dict = c.create_eval_dataset_dict(
    eval_image_folder,
    eval_label_folder,
) if do_validation else None

try:
    import wandb
    WANDB_INSTALLED = True
except ImportError:
    WANDB_INSTALLED = False

# -----------------------------
# Build training config
# -----------------------------
if use_default_advanced_parameters:
    train_config = WNetTrainingWorkerConfig(
        device="cuda:0",
        max_epochs=number_of_epochs,
        learning_rate=2e-5,
        validation_interval=2,
        batch_size=4,
        num_workers=2,
        weights_info=WeightsInfo(),
        results_path_folder=str(results_path),
        train_data_dict=c.create_dataset_dict_no_labs(train_data_folder),
        eval_volume_dict=eval_dict,
    )
else:
    train_config = WNetTrainingWorkerConfig(
        device="cuda:0",
        max_epochs=number_of_epochs,
        learning_rate=learning_rate,
        validation_interval=validation_frequency,
        batch_size=batch_size,
        num_workers=2,
        weights_info=WeightsInfo(),
        results_path_folder=str(results_path),
        train_data_dict=c.create_dataset_dict_no_labs(train_data_folder),
        eval_volume_dict=eval_dict,
        # advanced parameters:
        num_classes=num_classes,
        weight_decay=weight_decay,
        intensity_sigma=intensity_sigma,
        spatial_sigma=spatial_sigma,
        radius=ncuts_radius,
        reconstruction_loss=rec_loss,
        n_cuts_weight=n_cuts_weight,
        rec_loss_weight=rec_loss_weight,
    )

wandb_config = WandBConfig(
    mode="disabled" if not WANDB_INSTALLED else "online",
    save_model_artifact=False,
)

# -----------------------------
# Start training
# -----------------------------
worker = c.get_colab_worker(worker_config=train_config, wandb_config=wandb_config)

for epoch_loss in worker.train():
    # you can print / log epoch_loss here if you want
    pass

print("Training finished.")


monai.transforms.spatial.dictionary Orientationd.__init__:labels: Current default value of argument `labels=(('L', 'R'), ('P', 'A'), ('I', 'S'))` was changed in version None from `labels=(('L', 'R'), ('P', 'A'), ('I', 'S'))` to `labels=None`. Default value changed to None meaning that the transform now uses the 'space' of a meta-tensor, if applicable, to determine appropriate axis labels.
Loading dataset: 100%|██████████| 18/18 [00:00<00:00, 279620.27it/s]
Loading dataset:   0%|          | 0/5 [00:00<?, ?it/s]`data_array` is not of type `MetaTensor, assuming affine to be identity.
Loading dataset: 100%|██████████| 5/5 [00:00<00:00, 37.59it/s]
max_pool3d_with_indices_backward_cuda does not have a deterministic implementation, but you set 'torch.use_deterministic_algorithms(True, warn_only=True)'. You can file an issue at https://github.com/pytorch/pytorch/issues to help us prioritize adding deterministic support for this operation. (Triggered inte

Training finished.


### Inference using the fine-tuned model

In [7]:
import logging
from pathlib import Path

import numpy as np
from tifffile import imread, imwrite

import torch
from napari_cellseg3d.dev_scripts import remote_inference as cs3d
from napari_cellseg3d.utils import LOGGER as logger
from napari_cellseg3d.config import ModelInfo

In [8]:
from napari_cellseg3d.config import WeightsInfo
import inspect

print(WeightsInfo)
help(WeightsInfo)          # see its fields / constructor args

print("CONFIG:", cs3d.CONFIG)
print("CONFIG fields:", dir(cs3d.CONFIG))


<class 'napari_cellseg3d.config.WeightsInfo'>
Help on class WeightsInfo in module napari_cellseg3d.config:

class WeightsInfo(builtins.object)
 |  WeightsInfo(path: Optional[str] = '/home/ads4015/micromamba/envs/cellseg3d-env1/lib/python3.10/site-packages/napari_cellseg3d/code_models/models/pretrained', use_pretrained: Optional[bool] = False, use_custom: Optional[bool] = False) -> None
 |  
 |  Class to record params for weights.
 |  
 |  Args:
 |      path (Optional[str]): path to weights
 |      use_custom (Optional[bool]): whether to use custom weights
 |      use_pretrained (Optional[bool]): whether to use pretrained weights
 |  
 |  Methods defined here:
 |  
 |  __eq__(self, other)
 |      Return self==value.
 |  
 |  __init__(self, path: Optional[str] = '/home/ads4015/micromamba/envs/cellseg3d-env1/lib/python3.10/site-packages/napari_cellseg3d/code_models/models/pretrained', use_pretrained: Optional[bool] = False, use_custom: Optional[bool] = False) -> None
 |      Initialize se

In [10]:
import torch
from collections import OrderedDict

src = "/midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/checkpoints/wnet_best_metric.pth"
dst = "/midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/checkpoints/wnet_best_metric_for_inference.pth"

raw = torch.load(src, map_location="cpu")

# Strip the "module." prefix that nn.DataParallel adds to parameter names, if present
new_state = OrderedDict()
for k, v in raw.items():
    new_state[k.removeprefix("module.")] = v

wrapped = {"state_dict": new_state}

torch.save(wrapped, dst)
print("Saved wrapped checkpoint to:", dst)


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.


Saved wrapped checkpoint to: /midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/checkpoints/wnet_best_metric_for_inference.pth


In [11]:
# ------------------------------------------------------------
# 0. Imports
# ------------------------------------------------------------
import logging
from copy import deepcopy
from pathlib import Path

import numpy as np
from tifffile import imread, imwrite

import torch
import napari_cellseg3d
from napari_cellseg3d.dev_scripts import remote_inference as cs3d
from napari_cellseg3d.utils import LOGGER as logger
from napari_cellseg3d.config import ModelInfo, WeightsInfo

# Make CellSeg3D log info-level messages
logger.setLevel(logging.INFO)

print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    import pyclesperanto as cle
    cle.select_device()
    print("Selected OpenCL device:", cle.get_device())

# ------------------------------------------------------------
# 1. Paths
# ------------------------------------------------------------
test_vol_dir = Path(
    "/midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/tif_data/test/vol"
)
test_lab_dir = Path(
    "/midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/tif_data/test/lab"
)

pred_root = Path(
    "/midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/preds/test"
)
pred_root.mkdir(parents=True, exist_ok=True)

ckpt_path = Path(
    "/midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/checkpoints/wnet_best_metric_for_inference.pth"
)
print("Using custom weights:", ckpt_path)
assert ckpt_path.exists(), "Checkpoint path does not exist!"

print("Test vols:", [p.name for p in sorted(test_vol_dir.glob('*.tif'))])
print("Test labs:", [p.name for p in sorted(test_lab_dir.glob('*.tif'))])
print("Saving preds under:", pred_root)

# ------------------------------------------------------------
# 2. Build inference + post-processing config
# ------------------------------------------------------------
# Start from the default CONFIG but don't mutate the global one
inference_config = deepcopy(cs3d.CONFIG)

model_selection = "WNet3D"
print(f"Selected model: {model_selection}")

# Tell CellSeg3D which model and patch size to use
inference_config.model_info = ModelInfo(
    name=model_selection,
    model_input_size=[64, 64, 64],  # sliding-window patch size used for inference
    num_classes=2,                  # nuclei vs background
)

# Point weights_config to your finetuned checkpoint
inference_config.weights_config = WeightsInfo(
    path=str(ckpt_path),
    use_pretrained=False,
    use_custom=True,  # <- IMPORTANT: tell it to use your custom weights
)

# Post-processing config (thresholding + instance seg)
post_process_config = cs3d.PostProcessConfig()

# ------------------------------------------------------------
# 3. Helper: run inference on one volume
# ------------------------------------------------------------
def run_inference_on_volume(vol_zyx: np.ndarray):
    """
    vol_zyx: np.ndarray with shape (Z, Y, X), float32, ideally in [0, 1].
    Returns:
        semantic_zyx (float32, model output before thresholding)
        instance_zyx (uint16)
    """
    # Run the model
    result_list = cs3d.inference_on_images(vol_zyx, config=inference_config)
    result = result_list[0]

    # For WNet3D, the semantic output is often (C, Z, Y, X). Channel 1 is assumed to be
    # nuclei here; with self-supervised training the foreground may end up in either
    # channel, so check the output visually and switch the index if needed.
    semantic = result.semantic_segmentation
    if model_selection == "WNet3D" and semantic.ndim == 4:
        semantic = semantic[1]  # select nuclei channel -> (Z, Y, X)

    # Post-process to instance segmentation
    instance_seg, _stats = cs3d.post_processing(
        semantic,
        config=post_process_config,
    )

    # Keep the semantic map as float32: casting soft values in [0, 1] to uint8
    # would truncate everything below 1.0 to 0
    semantic = semantic.astype(np.float32)
    instance_seg = instance_seg.astype(np.uint16)
    return semantic, instance_seg

# ------------------------------------------------------------
# 4. Loop over test volumes and save predictions
# ------------------------------------------------------------
for vol_path in sorted(test_vol_dir.glob("*.tif")):
    print(f"\n=== Processing {vol_path.name} ===")

    # Load volume as (Z, Y, X)
    vol = imread(str(vol_path)).astype(np.float32)

    # Min-max normalize to [0, 1], as done before training
    vmin, vmax = vol.min(), vol.max()
    if vmax > vmin:
        vol_norm = (vol - vmin) / (vmax - vmin)
    else:
        vol_norm = vol

    # Run inference
    semantic_zyx, instance_zyx = run_inference_on_volume(vol_norm)

    # Output filenames
    base = vol_path.stem  # e.g. "patch_004_vol004_ch0"
    out_sem = pred_root / f"{base}_cellseg3d_wnet_semantic.tif"
    out_inst = pred_root / f"{base}_cellseg3d_wnet_instances.tif"

    imwrite(out_sem, semantic_zyx)
    imwrite(out_inst, instance_zyx)

    print("  Saved semantic to: ", out_sem)
    print("  Saved instances to:", out_inst)

print("\nDone! Predictions are in:", pred_root)


CUDA available: True
Selected OpenCL device: 0x5636d6a89110
Using custom weights: /midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/checkpoints/wnet_best_metric_for_inference.pth
Test vols: ['patch_004_vol004_ch0.tif', 'patch_006_vol006_ch0.tif']
Test labs: ['patch_004_vol004_ch0_label.tif', 'patch_006_vol006_ch0_label.tif']
Saving preds under: /midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/preds/test
Selected model: WNet3D

=== Processing patch_004_vol004_ch0.tif ===
--------------------
Parameters summary :
Model is : WNet3D
Window inference is enabled
Window size is 64
Window overlap is 0.25
Dataset loaded on cuda device
--------------------
MODEL DIMS : [64, 64, 64]
Model name : WNet3D
Instantiating model...
Loading weights...
Weights status : None
Done
--------------------
Parameters summary :
Model is : WNet3D
Window inference is enabled
Window size is 64
Window overlap is 0.25
Dataset loaded on cuda device
--------------------
Loa

INFO:napari_cellseg3d.dev_scripts.remote_inference:Thresholding with 0.4
INFO:napari_cellseg3d.dev_scripts.remote_inference:Clearing large objects with 500


Post-processing...
Layer prediction saved as : volume_WNet3D_pred_1_2025_11_26_19_41_39


0it [00:00, ?it/s]Only one label was provided to `remove_small_objects`. Did you mean to use a boolean array?
1it [00:00,  1.16it/s]
INFO:napari_cellseg3d.dev_scripts.remote_inference:Running instance segmentation with 0.55 and 0.55
invalid value encountered in divide
INFO:napari_cellseg3d.dev_scripts.remote_inference:Clearing small objects with 5


size: 96
  Saved semantic to:  /midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/preds/test/patch_004_vol004_ch0_cellseg3d_wnet_semantic.tif
  Saved instances to: /midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/preds/test/patch_004_vol004_ch0_cellseg3d_wnet_instances.tif

=== Processing patch_006_vol006_ch0.tif ===
--------------------
Parameters summary :
Model is : WNet3D
Window inference is enabled
Window size is 64
Window overlap is 0.25
Dataset loaded on cuda device
--------------------
MODEL DIMS : [64, 64, 64]
Model name : WNet3D
Instantiating model...
Loading weights...
Weights status : None
Done
--------------------
Parameters summary :
Model is : WNet3D
Window inference is enabled
Window size is 64
Window overlap is 0.25
Dataset loaded on cuda device
--------------------
Loading layer
2025-11-26 19:41:41,384 - INFO - Apply pending transforms - lazy: False, pending: 0, upcoming 'function', transform is not lazy
2025-11-26 19:41:4

INFO:napari_cellseg3d.dev_scripts.remote_inference:Thresholding with 0.4
INFO:napari_cellseg3d.dev_scripts.remote_inference:Clearing large objects with 500


Post-processing...
Layer prediction saved as : volume_WNet3D_pred_1_2025_11_26_19_41_41


1it [00:00, 10.69it/s]
INFO:napari_cellseg3d.dev_scripts.remote_inference:Running instance segmentation with 0.55 and 0.55
INFO:napari_cellseg3d.dev_scripts.remote_inference:Clearing small objects with 5


size: 96


divide by zero encountered in scalar divide
invalid value encountered in scalar multiply


  Saved semantic to:  /midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/preds/test/patch_006_vol006_ch0_cellseg3d_wnet_semantic.tif
  Saved instances to: /midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/preds/test/patch_006_vol006_ch0_cellseg3d_wnet_instances.tif

Done! Predictions are in: /midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/preds/test
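The test labels listed above are not used by the inference cell. A minimal sketch for scoring the saved semantic maps against them with the Dice coefficient; `dice_score` is a hypothetical helper, the 0.4 threshold is taken from the `Thresholding with 0.4` log line, and any non-zero label voxel is treated as foreground.

```python
import numpy as np


def dice_score(pred: np.ndarray, target: np.ndarray, threshold: float = 0.4) -> float:
    """Dice coefficient between a thresholded prediction and a label volume.

    Hypothetical helper: voxels of `pred` above `threshold` count as foreground,
    and any non-zero voxel of `target` counts as foreground.
    """
    if pred.shape != target.shape:
        raise ValueError(f"Shape mismatch: {pred.shape} vs {target.shape}")
    pred_mask = pred > threshold
    target_mask = target > 0
    denom = int(pred_mask.sum()) + int(target_mask.sum())
    if denom == 0:
        return 1.0  # both volumes empty: count as perfect agreement
    overlap = int(np.logical_and(pred_mask, target_mask).sum())
    return 2.0 * overlap / denom
```

For each test `base`, compare `imread(pred_root / f"{base}_cellseg3d_wnet_semantic.tif")` with `imread(test_lab_dir / f"{base}_label.tif")`; both were written in (Z, Y, X) order, so the shapes should match. A score close to 0 may mean the foreground ended up in the other WNet3D channel.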
