<a href="https://colab.research.google.com/github/AdaptiveMotorControlLab/CellSeg3d/blob/main/notebooks/Colab_WNet3D_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **WNet3D: self-supervised 3D cell segmentation**

---

This notebook is part of the [CellSeg3D project](https://github.com/AdaptiveMotorControlLab/CellSeg3d) in the [Mathis Lab of Adaptive Intelligence](https://www.mackenziemathislab.org/).

- ðŸ’œ The foundation of this notebook owes much to the **[ZeroCostDL4Mic](https://github.com/HenriquesLab/ZeroCostDL4Mic)** project and to the **[DeepLabCut](https://github.com/DeepLabCut/DeepLabCut)** team for bringing Colab into scientific open software.

#**1. Installing dependencies**
---

In [None]:
# #@markdown ##Play to install CellSeg3D and WNet3D dependencies:
# !pip install -q napari-cellseg3d
# print("Dependencies installed")

##**1.2 Load key dependencies**
---

In [None]:
# @title
from pathlib import Path
from napari_cellseg3d.dev_scripts import colab_training as c
from napari_cellseg3d.config import WNetTrainingWorkerConfig, WandBConfig, WeightsInfo, PRETRAINED_WEIGHTS_DIR

## Optional - *1.3 Initialize Weights & Biases integration*
---
If you wish to utilize Weights & Biases (WandB) for monitoring and logging your training session, uncomment and execute the cell below.
To enable it, just input your API key in the space provided.

In [None]:
# !pip install -q wandb
# import wandb
# wandb.login()

# **2. Complete the Colab session**
---



## **2.1. Check for GPU access**
---

By default, this session is configured to use Python 3 and GPU acceleration. To verify or adjust these settings:

<font size = 4>Navigate to Runtime and select Change the Runtime type.

<font size = 4>For Runtime type, ensure it's set to Python 3 (the programming language this program is written in).

<font size = 4>Under Accelerator, choose GPU (Graphics Processing Unit).


In [None]:
#@markdown ##Execute the cell below to verify if GPU access is available.

import torch
if not torch.cuda.is_available():
  print('You do not have GPU access.')
  print('Did you change your runtime?')
  print('If the runtime setting is correct then Google did not allocate a GPU for your session')
  print('Expect slow performance. To access GPU try reconnecting later')

else:
  print('You have GPU access')
  !nvidia-smi


## **2.2. Mount Google Drive**
---
<font size = 4>To integrate this notebook with your personal data, save your data on Google Drive in accordance with the directory structures detailed in Section 0.

1. <font size = 4> **Run** the **cell** below and click on the provided link.

2. <font size = 4>Log in to your Google account and grant the necessary permissions by clicking 'Allow'.

3. <font size = 4>Copy the generated authorization code and paste it into the cell, then press 'Enter'. This grants Colab access to read and write data to your Google Drive.

4. <font size = 4> After completion, you can view your data in the notebook. Simply click the Files tab on the top left and select 'Refresh'.

In [None]:
# # mount user's Google Drive to Google Colab.
# from google.colab import drive
# drive.mount('/content/gdrive')

**<font size = 4> If you cannot see your files, reactivate your session by connecting to your hosted runtime.**


<img width="40%" alt ="Example of image detection with retinanet." src="https://github.com/HenriquesLab/ZeroCostDL4Mic/raw/master/Wiki_files/connect_to_hosted.png"><figcaption> Connect to a hosted runtime. </figcaption>

# **3. Select your parameters and paths**
---

## **3.1. Choosing parameters**

---

### **Paths to the training data and model**

* <font size = 4>**`training_source`** specifies the paths to the training data. They must be a single multipage TIF file each

* <font size = 4>**`model_save_path`** specifies the directory where the model checkpoints will be saved.

<font size = 4>**Tip:** To easily copy paths, navigate to the 'Files' tab, right-click on a folder or file, and choose 'Copy path'.

### **Training parameters**

* <font size = 4>**`number_of_epochs`** is the number of times the entire training data will be seen by the model. Default: 50

* <font size = 4>**`batchs_size`** is the number of image that will be bundled together at each training step. Default: 4

* <font size = 4>**`learning_rate`** is the step size of the update of the model's weight. Try decreasing it if the NCuts loss is unstable. Default: 2e-5

* <font size = 4>**`num_classes`** is the number of brightness clusters to segment the image in. Try raising it to 3 if you have artifacts or "halos" around your cells that have significantly different brightness. Default: 2

* <font size = 4>**`weight_decay`** is a regularization parameter used to prevent overfitting. Default: 0.01

* <font size = 4>**`validation_frequency`** is the frequency at which the provided evaluation data is used to estimate the model's performance.

* <font size = 4>**`intensity_sigma`** is the standard deviation of the feature similarity term. Default: 1

* <font size = 4>**`spatial_sigma`** is the standard deviation of the spatial proximity term. Default: 4

* <font size = 4>**`ncuts_radius`** is the radius for the NCuts loss computation, in pixels. Default: 2

* <font size = 4>**`rec_loss`** is the loss to use for the decoder. Can be Mean Square Error (MSE) or Binary Cross Entropy (BCE). Default : MSE

* <font size = 4>**`n_cuts_weight`** is the weight of the NCuts loss in the weighted sum for the backward pass. Default: 0.5
* <font size = 4>**`rec_loss_weight`** is the weight of the reconstruction loss. Default: 0.005


In [None]:
#@markdown ###Path to the training data:
training_source = "/midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/train" #@param {type:"string"}
#@markdown ###Path to save the weights (make sure to have enough space in your drive):
model_save_path = "/midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/checkpoints" #@param {type:"string"}
#@markdown ---
#@markdown ###Perform validation on a test dataset (optional):
do_validation = False #@param {type:"boolean"}
#@markdown ###Path to evaluation data (optional, use if checked above):
eval_source = "/midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/val/vol" #@param {type:"string"}
eval_target = "/midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/val/lab" #@param {type:"string"}
#@markdown ---
#@markdown ###Training parameters
number_of_epochs = 50 #@param {type:"number"}
#@markdown ###Default advanced parameters
use_default_advanced_parameters = False #@param {type:"boolean"}
#@markdown <font size = 4>If not, please change:

#@markdown <font size = 3>Training parameters:
batch_size =  4 #@param {type:"number"}
learning_rate = 2e-5 #@param {type:"number"}
num_classes = 2 #@param {type:"number"}
weight_decay = 0.01 #@param {type:"number"}
#@markdown <font size = 3>Validation parameters:
validation_frequency = 2 #@param {type:"number"}
#@markdown <font size = 3>SoftNCuts parameters:
intensity_sigma = 1.0 #@param {type:"number"}
spatial_sigma = 4.0 #@param {type:"number"}
ncuts_radius = 2 #@param {type:"number"}
#@markdown <font size = 3>Reconstruction loss:
rec_loss = "MSE" #@param["MSE", "BCE"]
#@markdown <font size = 3>Weighted sum of losses:
n_cuts_weight = 0.5 #@param {type:"number"}
rec_loss_weight = 0.005 #@param {type:"number"}

# **4. Train the network**
---

<font size = 4>Important Reminder: Google Colab imposes a maximum session time to prevent extended GPU usage, such as for data mining. Ensure your training duration stays under 12 hours. If your training is projected to exceed this limit, consider reducing the `number_of_epochs`.

## **4.1. Initialize the config**
---

In [None]:
# @title
train_data_folder = Path(training_source)
results_path = Path(model_save_path)
results_path.mkdir(exist_ok=True)
eval_image_folder = Path(eval_source)
eval_label_folder = Path(eval_target)

eval_dict = c.create_eval_dataset_dict(
        eval_image_folder,
        eval_label_folder,
    ) if do_validation else None

try:
  import wandb
  WANDB_INSTALLED = True
except ImportError:
  WANDB_INSTALLED = False


train_config = WNetTrainingWorkerConfig(
    device="cuda:0",
    max_epochs=number_of_epochs,
    learning_rate=2e-5,
    validation_interval=2,
    batch_size=4,
    num_workers=2,
    weights_info=WeightsInfo(),
    results_path_folder=str(results_path),
    train_data_dict=c.create_dataset_dict_no_labs(train_data_folder),
    eval_volume_dict=eval_dict,
) if use_default_advanced_parameters else WNetTrainingWorkerConfig(
    device="cuda:0",
    max_epochs=number_of_epochs,
    learning_rate=learning_rate,
    validation_interval=validation_frequency,
    batch_size=batch_size,
    num_workers=2,
    weights_info=WeightsInfo(),
    results_path_folder=str(results_path),
    train_data_dict=c.create_dataset_dict_no_labs(train_data_folder),
    eval_volume_dict=eval_dict,
    # advanced
    num_classes=num_classes,
    weight_decay=weight_decay,
    intensity_sigma=intensity_sigma,
    spatial_sigma=spatial_sigma,
    radius=ncuts_radius,
    reconstruction_loss=rec_loss,
    n_cuts_weight=n_cuts_weight,
    rec_loss_weight=rec_loss_weight,
)
wandb_config = WandBConfig(
    mode="disabled" if not WANDB_INSTALLED else "online",
    save_model_artifact=False,
)

## **4.2. Start training**
---

In [None]:
# @title
worker = c.get_colab_worker(worker_config=train_config, wandb_config=wandb_config)
for epoch_loss in worker.train():
  continue

In [None]:
# Once you have trained the model, you will have the weights as a .pth file

## Adapt for Selma data

## Create tiff dataset

In [1]:
from pathlib import Path
import numpy as np
import nibabel as nib
import tifffile as tiff
import csv
import random

# ----------------------------
# Paths
# ----------------------------
nii_root = Path("/midtier/paetzollab/scratch/ads4015/data_selma3d/selma3d_finetune_patches/cell_nucleus_patches")
base_out = Path("/midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning")
tif_root = base_out / "tif_data"

splits = {
    "train": (tif_root / "train" / "vol", tif_root / "train" / "lab"),
    "val":   (tif_root / "val" / "vol",   tif_root / "val" / "lab"),
    "test":  (tif_root / "test" / "vol",  tif_root / "test" / "lab"),
}

# Make dirs
for vol_dir, lab_dir in splits.values():
    vol_dir.mkdir(parents=True, exist_ok=True)
    lab_dir.mkdir(parents=True, exist_ok=True)

# ----------------------------
# Collect image files
# ----------------------------
all_imgs = sorted(
    f for f in nii_root.glob("*.nii.gz")
    if f.name.endswith("_ch0.nii.gz") and "_label" not in f.name
)

print(f"Found {len(all_imgs)} patches")
if len(all_imgs) < 4:
    raise RuntimeError("Need at least 4 patches.")

# ----------------------------
# Deterministic split
# ----------------------------
random.Random(123).shuffle(all_imgs)

n_test = 2
n_val = max(1, int(0.2 * (len(all_imgs)-n_test)))

test_imgs  = all_imgs[:n_test]
val_imgs   = all_imgs[n_test:n_test+n_val]
train_imgs = all_imgs[n_test+n_val:]

print(f"Train={len(train_imgs)}  Val={len(val_imgs)}  Test={len(test_imgs)}")

# ----------------------------
# Helpers
# ----------------------------
def load_nifti(path: Path):
    arr = nib.load(str(path)).get_fdata()
    arr = np.squeeze(arr)
    if arr.ndim != 3:
        raise ValueError(f"Expected 3D but got {arr.shape}")
    return np.transpose(arr, (2, 1, 0))  # XYZ â†’ ZYX

def label_path(img: Path):
    return img.with_name(img.name.replace(".nii.gz", "_label.nii.gz"))

# ----------------------------
# Convert and save
# ----------------------------
rows = []
def convert_split(name, imgs):
    vol_dir, lab_dir = splits[name]

    for img_path in imgs:
        lab_path = label_path(img_path)
        if not lab_path.exists():
            raise FileNotFoundError(f"Missing label: {lab_path}")

        img = load_nifti(img_path).astype(np.float32)
        lab = load_nifti(lab_path).astype(np.uint8)

        base = img_path.name.removesuffix(".nii.gz")
        out_img = vol_dir / f"{base}.tif"
        out_lab = lab_dir / f"{base}_label.tif"

        tiff.imwrite(out_img, img)
        tiff.imwrite(out_lab, lab)

        rows.append({
            "split": name,
            "nii_image": str(img_path),
            "nii_label": str(lab_path),
            "tif_image": str(out_img),
            "tif_label": str(out_lab),
        })

convert_split("train", train_imgs)
convert_split("val",   val_imgs)
convert_split("test",  test_imgs)

# Save CSV
csv_path = base_out / "splits_cell_nucleus.csv"
with open(csv_path, "w", newline="") as f:
    writer = csv.DictWriter(f, rows[0].keys())
    writer.writeheader()
    writer.writerows(rows)

print("Saved split info â†’", csv_path)
print("TIF dataset ready.")


Found 25 patches
Train=19  Val=4  Test=2
Saved split info â†’ /midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/splits_cell_nucleus.csv
TIF dataset ready.


## Train WNet

In [2]:
from pathlib import Path
from napari_cellseg3d.dev_scripts import colab_training as c
from napari_cellseg3d.config import WNetTrainingWorkerConfig, WandBConfig, WeightsInfo
from tqdm.auto import tqdm

# ----------------------------
# Training paths
# ----------------------------
train_vol = tif_root / "train" / "vol"
val_vol   = tif_root / "val" / "vol"
val_lab   = tif_root / "val" / "lab"
save_dir  = base_out / "checkpoints"
save_dir.mkdir(exist_ok=True, parents=True)

# ----------------------------
# Training hyperparameters
# ----------------------------
num_epochs = 50
validation_freq = 2

batch_size = 4
learning_rate = 2e-5
num_classes = 2
weight_decay = 0.01

# Tell CellSeg3D to start from the built-in WNet pretrained weights
pretrained_wnet_weights = WeightsInfo(
    use_pretrained=True,   # <- start from pretrained WNet
    use_custom=False,      # <- not using a custom checkpoint for training
    path=None,             # or "" â€“ path is only needed for custom weights
)


# Tell CellSeg3D to start from the built-in WNet pretrained weights
pretrained_wnet_weights = WeightsInfo(
    use_pretrained=True,
    use_custom=False,
    path=None,
)

train_config = WNetTrainingWorkerConfig(
    device="cuda:0",
    max_epochs=num_epochs,
    learning_rate=learning_rate,
    validation_interval=validation_freq,
    batch_size=batch_size,
    num_workers=2,
    weights_info=pretrained_wnet_weights,   # <- using WNet pretrained weights
    results_path_folder=str(save_dir),
    train_data_dict=c.create_dataset_dict_no_labs(train_vol),
    eval_volume_dict=c.create_eval_dataset_dict(val_vol, val_lab),

    # Advanced parameters:
    num_classes=num_classes,
    weight_decay=weight_decay,
    intensity_sigma=1.0,
    spatial_sigma=4.0,
    radius=2,
    reconstruction_loss="MSE",
    n_cuts_weight=0.5,
    rec_loss_weight=0.005,
)



twcfg = train_config
winfo = twcfg.weights_info

if getattr(winfo, "use_custom", False) and getattr(winfo, "path", None):
    print(f"[Train] Using custom weights from: {winfo.path}")
elif getattr(winfo, "use_pretrained", False):
    print("[Train] Starting from pretrained WNet weights.")
else:
    print("[Train] Starting from random init weights.")



wandb_config = WandBConfig(mode="disabled")

worker = c.get_colab_worker(worker_config=train_config, wandb_config=wandb_config)

# ----------------------------
# Training loop with live losses
# ----------------------------
training_gen = worker.train()

pbar = tqdm(total=num_epochs, desc="Training", unit="epoch")
prev_epochs = 0

for _ in training_gen:
    epochs_done = len(worker.total_losses)

    if epochs_done > prev_epochs:
        prev_epochs = epochs_done
        idx = epochs_done

        train_loss = worker.total_losses[-1]
        rec_loss   = worker.rec_losses[-1]
        ncuts_loss = worker.ncuts_losses[-1]

        postfix = {
            "train": f"{train_loss:.4f}",
            "rec":   f"{rec_loss:.4f}",
            "ncuts": f"{ncuts_loss:.4f}",
        }

        if worker.dice_values:
            postfix["val_dice"] = f"{worker.dice_values[-1]:.4f}"

        pbar.set_description(f"Epoch {idx}")
        pbar.set_postfix(postfix)
        pbar.update(1)

pbar.close()
print("Training complete.")


[Train] Starting from pretrained WNet weights.


Training:   0%|          | 0/50 [00:00<?, ?epoch/s]

monai.transforms.spatial.dictionary Orientationd.__init__:labels: Current default value of argument `labels=(('L', 'R'), ('P', 'A'), ('I', 'S'))` was changed in version None from `labels=(('L', 'R'), ('P', 'A'), ('I', 'S'))` to `labels=None`. Default value changed to None meaning that the transform now uses the 'space' of a meta-tensor, if applicable, to determine appropriate axis labels.
Loading dataset: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 19/19 [00:00<00:00, 270141.61it/s]
`data_array` is not of type `MetaTensor, assuming affine to be identity.
Loading dataset: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 4/4 [00:00<00:00, 36.27it/s]
You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default

Training complete.


## Inference

In [None]:
import torch
import numpy as np
from copy import deepcopy
from tifffile import imread, imwrite
from napari_cellseg3d.dev_scripts import remote_inference as cs3d
from napari_cellseg3d.config import ModelInfo, WeightsInfo
from pathlib import Path

test_vol = tif_root / "test" / "vol"
pred_root = base_out / "preds" / "test"
pred_root.mkdir(parents=True, exist_ok=True)

# load your checkpoint
ckpt = save_dir / "wnet_best_metric.pth"
assert ckpt.exists()

# Build inference configuration
inference_config = deepcopy(cs3d.CONFIG)
inference_config.model_info = ModelInfo(
    name="WNet3D",
    model_input_size=[64, 64, 64],
    num_classes=2,
)
inference_config.weights_config = WeightsInfo(
    path=str(ckpt),
    use_pretrained=False,
    use_custom=True,
)

wcfg = inference_config.weights_config

if wcfg.use_custom and wcfg.path:
    print(f"[CellSeg3D] Using FINETUNED custom weights from: {wcfg.path}")
elif wcfg.use_pretrained:
    print("[CellSeg3D] Using pretrained WNet weights (built-in).")
else:
    print("[CellSeg3D] WARNING: No custom or pretrained weights selected â€“ using random init!")


pp_config = cs3d.PostProcessConfig()

def infer_one(vol):
    out = cs3d.inference_on_images(vol, config=inference_config)[0]
    sem = out.semantic_segmentation
    if sem.ndim == 4:
        sem = sem[1]   # take nuclei channel
    inst, _ = cs3d.post_processing(sem, config=pp_config)
    return sem.astype(np.uint8), inst.astype(np.uint16)

for vol_path in sorted(test_vol.glob("*.tif")):
    print("Processing", vol_path.name)
    vol = imread(str(vol_path)).astype(np.float32)
    vmin, vmax = vol.min(), vol.max()
    vol_norm = (vol - vmin) / (vmax - vmin) if vmax > vmin else vol

    sem, inst = infer_one(vol_norm)

    base = vol_path.stem
    imwrite(pred_root / f"{base}_semantic.tif", sem)
    imwrite(pred_root / f"{base}_instances.tif", inst)

print("Done! Predictions saved to", pred_root)


[CellSeg3D] Using FINETUNED custom weights from: /midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/checkpoints/wnet_best_metric.pth
Processing patch_004_vol004_ch0.tif
--------------------
Parameters summary :
Model is : WNet3D
Window inference is enabled
Window size is 64
Window overlap is 0.25
Dataset loaded on cuda device
--------------------
MODEL DIMS : [64, 64, 64]
Model name : WNet3D
Instantiating model...
Loading weights...
Weights status : None
Done
--------------------
Parameters summary :
Model is : WNet3D
Window inference is enabled
Window size is 64
Window overlap is 0.25
Dataset loaded on cuda device
--------------------
Loading layer
2025-11-28 14:53:15,456 - INFO - Apply pending transforms - lazy: False, pending: 0, upcoming 'function', transform is not lazy
2025-11-28 14:53:15,458 - INFO - Apply pending transforms - lazy: False, pending: 0, upcoming 'ToTensor', transform is not lazy
2025-11-28 14:53:15,462 - INFO - Apply pending transforms - lazy

You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.


Post-processing...


INFO:napari_cellseg3d.dev_scripts.remote_inference:Thresholding with 0.4
INFO:napari_cellseg3d.dev_scripts.remote_inference:Clearing large objects with 500


Layer prediction saved as : volume_WNet3D_pred_1_2025_11_28_14_53_15


1it [00:00, 12.02it/s]
INFO:napari_cellseg3d.dev_scripts.remote_inference:Running instance segmentation with 0.55 and 0.55
INFO:napari_cellseg3d.dev_scripts.remote_inference:Clearing small objects with 5


size: 96


divide by zero encountered in scalar divide
invalid value encountered in scalar multiply


Processing patch_006_vol006_ch0.tif
--------------------
Parameters summary :
Model is : WNet3D
Window inference is enabled
Window size is 64
Window overlap is 0.25
Dataset loaded on cuda device
--------------------
MODEL DIMS : [64, 64, 64]
Model name : WNet3D
Instantiating model...
Loading weights...
Weights status : None
Done
--------------------
Parameters summary :
Model is : WNet3D
Window inference is enabled
Window size is 64
Window overlap is 0.25
Dataset loaded on cuda device
--------------------
Loading layer
2025-11-28 14:53:16,613 - INFO - Apply pending transforms - lazy: False, pending: 0, upcoming 'function', transform is not lazy
2025-11-28 14:53:16,615 - INFO - Apply pending transforms - lazy: False, pending: 0, upcoming 'ToTensor', transform is not lazy
2025-11-28 14:53:16,618 - INFO - Apply pending transforms - lazy: False, pending: 0, upcoming 'EnsureType', transform is not lazy
Done
----------
Inference started on layer...


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.


Post-processing...


INFO:napari_cellseg3d.dev_scripts.remote_inference:Thresholding with 0.4
INFO:napari_cellseg3d.dev_scripts.remote_inference:Clearing large objects with 500


Layer prediction saved as : volume_WNet3D_pred_1_2025_11_28_14_53_16


1it [00:00, 14.06it/s]
INFO:napari_cellseg3d.dev_scripts.remote_inference:Running instance segmentation with 0.55 and 0.55


size: 96


INFO:napari_cellseg3d.dev_scripts.remote_inference:Clearing small objects with 5


Done! Predictions saved to /midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/preds/test


## Check to ensure pretrained weights are used correctly

In [None]:
import torch
from napari_cellseg3d.code_models.models.wnet.model import WNet

# --- 1. Load checkpoint ---
ckpt_path = "/midtier/paetzollab/scratch/ads4015/compare_methods/cellseg3d/finetuning/checkpoints/wnet_best_metric.pth"
state_dict = torch.load(ckpt_path)

# pick a parameter to compare
param_name = "encoder.conv1.module.0.weight"

# Value inside checkpoint
ckpt_val = state_dict[param_name][0,0,0,0,0].item()
print("Checkpoint value:", ckpt_val)

# --- 2. Instantiate model with correct architecture (matching training) ---
model_finetuned = WNet(
    in_channels=1,
    out_channels=1,   # important!!
    num_classes=2,
    dropout=0.65,
)

# Load finetuned weights into the model
model_finetuned.load_state_dict(state_dict, strict=True)

# Extract the same parameter from the model
model_finetuned_val = model_finetuned.encoder.conv1.module[0].weight[0,0,0,0,0].item()
print("Finetuned model value:", model_finetuned_val)

# Check match between checkpoint and loaded model
print("Match finetuned? :", ckpt_val == model_finetuned_val)

# --- 3. Create a randomly initialized model for comparison ---
model_random = WNet(
    in_channels=1,
    out_channels=1,
    num_classes=2,
    dropout=0.65,
)

# Extract the same parameter from the random model
model_random_val = model_random.encoder.conv1.module[0].weight[0,0,0,0,0].item()
print("Random model value:", model_random_val)

# Check if random model equals finetuned model
print("Random equals finetuned? :", model_random_val == model_finetuned_val)


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.


Checkpoint value: 2.155968786610174e-06
Finetuned model value: 2.155968786610174e-06
Match finetuned? : True
Random model value: -0.015734711661934853
Random equals finetuned? : False
