# Script to segment your 3D data 
3D image segmentation script using deep-learning pixel classification based on a U-Net

This script performs slice-by-slice inference on volumetric microscopy images 
using a pretrained PyTorch model. It supports multiple input formats including 
CZI, TIFF, and OME-TIFF, and produces multi-channel OME-TIFF outputs containing 
either classification masks or class probability maps.

Features:
- Sliding window inference for large 2D slices (Z-stack compatible)
- Export options: classification masks (binary per class) or probability maps
- Selectable output classes (EXPORT_CLASSES)
- One file processed at a time with explicit memory cleanup (each output volume is assembled in RAM and written in a single pass)
- Output format compatible with Fiji, Napari, Imaris, etc.

Note:
Class indices start at 0 (PyTorch convention), while annotation tools like Napari 
typically use labels starting at 1. Adjust EXPORT_CLASSES accordingly.

In [1]:
## Import libraries
import torch
import sys
import numpy as np
from pathlib import Path
import tifffile
from tifffile import TiffWriter, imwrite, TiffFile
import czifile
import gc
from tqdm import tqdm
from monai.inferers import sliding_window_inference
from tnia.deeplearning.dl_helper import quantile_normalization
import torch.nn.functional as F

raster_geometry not imported.  This is only needed for the ellipsoid rendering in apply_stardist


### Configuration section
Define base directories, model path, input/output folders. Make sure these paths and options match your system and use case.

In [2]:
# Project root; every other location below is derived from it.
BASE_PATH = Path(r'C:/Users/Alex/Desktop/Mailis')
DATA_DIR = BASE_PATH.joinpath("data")  # folder scanned for input volumes
MODEL_PATH = BASE_PATH.joinpath("models", "vessel_final_3.pth")  # trained model file
OUTPUT_DIR = BASE_PATH.joinpath("vessel_binary")  # predictions are saved here

ROI_SIZE = 512  # passed as roi_size to the sliding-window inferer
BATCH_SIZE = 1  # passed as sw_batch_size to the sliding-window inferer

# Glob patterns used to discover input files inside DATA_DIR.
SUPPORTED_EXTENSIONS = [
    "*.czi",
    "*.tif",
    "*.tiff",
    "*.ome.tif",
    "*.ome.tiff",
    "*.btf",
    "*.ome.btf",
]

### Model initialization and detection of the number of trained classes
Load the trained model and move it to the appropriate device (CPU or GPU). Then, detect how many output classes the model was trained with.

In [3]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved on a GPU machine can still be loaded on a CPU-only one.
# NOTE(security): weights_only=False unpickles arbitrary Python objects; only
# load model files that come from a trusted source.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False).to(device)
model.eval()  # Evaluation mode so layers like BatchNorm and Dropout behave correctly during inference.
torch.set_grad_enabled(False)  # Inference only: disable gradient tracking globally.

# --- Detect number of output classes from the model ---
# A single all-zero tile is pushed through the network; the size of the
# channel axis (dim 1) of the result is taken as the number of classes.
with torch.no_grad():
    dummy_input = torch.zeros(1, 1, ROI_SIZE, ROI_SIZE, device=device)
    dummy_output = sliding_window_inference(dummy_input, roi_size=ROI_SIZE, sw_batch_size=1, predictor=model)
    num_model_classes = dummy_output.shape[1]

print(f"Model was trained with {num_model_classes} classes (class indices: 0 to {num_model_classes - 1}).")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

Model was trained with 3 classes (class indices: 0 to 2).


### Choose output mode and classes to export

Choose between "classification" (argmax per pixel, exported as one 0/255 binary mask per class) and "probability" (softmax scores scaled to 8-bit values, 0–255).
Then, select which class indices you want to export in the final output.

NOTE:
In Napari, annotation labels usually start at 1. However, model outputs are indexed from 0. So Napari class 1 corresponds to model class index 0.

In [4]:
# Output flavour: "classification" (one binary mask per class) or
# "probability" (one softmax map per class).
EXPORT_MODE = "probability"

# Model class indices (0-based) written as output channels,
# e.g. [0, 1] or [1, 2, 3].
EXPORT_CLASSES = [0, 1, 2]


### Prediction

In [None]:
# IMAGE LOADING

def load_image(path: Path) -> np.ndarray:
    """Read a volume from disk and return it as a 3D array.

    The reader is picked from the lower-cased final suffix of ``path``:
    ``czifile`` for ``.czi`` and ``tifffile`` for the TIFF-family suffixes.
    Singleton axes are squeezed away before the dimensionality check.

    Args:
        path: Location of the image file.

    Returns:
        The squeezed image data with exactly three dimensions.

    Raises:
        ValueError: If the suffix is not supported, or if the squeezed
            array is not 3D.
    """
    tiff_suffixes = (".tif", ".tiff", ".ome.tif", ".ome.tiff", ".btf", ".ome.btf")
    suffix = path.suffix.lower()

    if suffix == ".czi":
        reader = czifile.imread
    elif suffix in tiff_suffixes:
        reader = tifffile.imread
    else:
        raise ValueError(f"Unsupported file format: {suffix}")

    volume = np.squeeze(reader(path))
    if volume.ndim == 3:
        return volume
    raise ValueError(f"Expected 3D image (Z, Y, X), got shape {volume.shape}")


# PREDICTION

def predict(image_2d: np.ndarray, model: torch.nn.Module) -> np.ndarray:
    """Return per-class softmax probabilities for a single slice.

    The slice is quantile-normalised, cast to float32, given leading batch
    and channel axes, and run through sliding-window inference on the
    module-level ``device``. Softmax is applied over the class axis.

    Args:
        image_2d: One image plane (assumed 2D, (H, W) — confirm with callers).
        model: Network used as the sliding-window predictor.

    Returns:
        Probabilities as a NumPy array with the batch axis removed,
        i.e. shape (C, H, W).
    """
    normalized = quantile_normalization(image_2d).astype(np.float32)
    # Two leading singleton axes: batch and channel.
    batch = torch.from_numpy(normalized)[None, None].to(device)

    with torch.no_grad():
        class_probs = F.softmax(
            sliding_window_inference(
                batch,
                roi_size=ROI_SIZE,
                sw_batch_size=BATCH_SIZE,
                predictor=model,
            ),
            dim=1,
        )

    return class_probs.squeeze(0).cpu().numpy()  # Shape: (C, H, W)


# PROCESS A SINGLE FILE

def _slice_to_channels(prob_map: np.ndarray, mode: str, classes) -> np.ndarray:
    """Turn one slice's class probabilities into uint8 output channels.

    Args:
        prob_map: Per-class probabilities for one slice, class axis first.
        mode: "probability" or "classification" (validated by the caller).
        classes: Class indices to export, one output channel each.

    Returns:
        uint8 array with one leading channel per requested class.
    """
    if mode == "probability":
        # Scale probabilities to 0..255; the uint8 cast truncates.
        return np.stack(
            [(prob_map[cls] * 255).astype(np.uint8) for cls in classes],
            axis=0,
        )

    # "classification": winning class per pixel, then one 0/255 mask per class.
    class_map = np.argmax(prob_map, axis=0)
    return np.stack(
        [(class_map == cls).astype(np.uint8) * 255 for cls in classes],
        axis=0,
    )


def process_file(image_path: Path, file_index: int, total_files: int) -> None:
    """Segment one volume slice by slice and save the result as a TIFF.

    The export settings are validated first so a misconfiguration fails
    before any image is loaded or any inference is run. Every slice is
    then predicted, converted to uint8 channels for the classes listed in
    ``EXPORT_CLASSES``, collected in a (C, Z, Y, X) array and written to
    ``OUTPUT_DIR`` in a single call.

    Args:
        image_path: Input image file.
        file_index: Zero-based position of this file in the batch
            (used only for the progress message).
        total_files: Number of files in the batch (progress message only).

    Raises:
        ValueError: If ``EXPORT_MODE`` is not "classification" or
            "probability", if ``EXPORT_CLASSES`` is empty, or if it
            contains an index outside the model's class range.
    """
    print(f"\n[{file_index + 1}/{total_files}] Processing: {image_path.name}")

    # Fail fast: an unknown mode would otherwise surface as a NameError deep
    # inside the slice loop, after the image load and a full inference.
    if EXPORT_MODE not in ("classification", "probability"):
        raise ValueError(
            f"Invalid EXPORT_MODE {EXPORT_MODE!r}. Use 'classification' or 'probability'."
        )
    if not EXPORT_CLASSES:
        raise ValueError("EXPORT_CLASSES is empty. Select at least one class index to export.")

    image_stack = load_image(image_path)
    depth, height, width = image_stack.shape

    # The first slice's prediction tells us how many classes the model
    # returns; it is kept and reused for z == 0 instead of being recomputed.
    first_probs = predict(image_stack[0], model)
    total_classes = first_probs.shape[0]

    for cls in EXPORT_CLASSES:
        if cls < 0 or cls >= total_classes:
            raise ValueError(f"Invalid class index {cls}. Model returns {total_classes} classes.")

    # EXPORT_MODE was validated above, so it doubles as the file-name suffix.
    save_path = OUTPUT_DIR / f"{image_path.stem}_{EXPORT_MODE}.ome.tif"

    print(f"Predicting {depth} slices with {len(EXPORT_CLASSES)} channels...")

    # Allocate array for final result: shape (C, Z, Y, X)
    output_stack = np.zeros((len(EXPORT_CLASSES), depth, height, width), dtype=np.uint8)

    bar_length = 30
    for z in range(depth):
        prob_map = first_probs if z == 0 else predict(image_stack[z], model)
        # Fill slice z for all exported classes.
        output_stack[:, z, :, :] = _slice_to_channels(prob_map, EXPORT_MODE, EXPORT_CLASSES)

        # Progress bar (carriage return redraws the same console line).
        progress = int((z + 1) / depth * bar_length)
        bar = '█' * progress + '-' * (bar_length - progress)
        sys.stdout.write(f"\r[{bar}] Slice {z + 1}/{depth}")
        sys.stdout.flush()

    # Write all slices in one go
    print(f"\nSaving to: {save_path}")
    tifffile.imwrite(
        save_path,
        data=output_stack,
        photometric='minisblack',
        metadata={'axes': 'CZYX'},
        bigtiff=True
    )
    print(f" Saved OME-TIFF: {save_path}")

    # Release the large arrays before the next file is loaded.
    del image_stack
    del output_stack
    del first_probs
    torch.cuda.empty_cache()
    gc.collect()
       


# MAIN SCRIPT

def main():
    """Find every supported image in DATA_DIR and process each one once.

    Files are collected with the glob patterns in ``SUPPORTED_EXTENSIONS``
    and handled in sorted order. If nothing matches, a message is printed
    and the function returns without doing any work.
    """
    # The patterns overlap: "x.ome.tif" matches both "*.tif" and "*.ome.tif".
    # Collecting into a set keeps each path once, so no file is processed
    # (and its output rewritten) twice.
    input_files = sorted({
        path
        for pattern in SUPPORTED_EXTENSIONS
        for path in DATA_DIR.glob(pattern)
    })

    if not input_files:
        print("No input image files found in the specified folder.")
        return

    print(f"{len(input_files)} image file(s) detected.")
    for idx, image_file in enumerate(input_files):
        print("")  # spacing
        process_file(image_file, idx, len(input_files))


if __name__ == "__main__":
    main()


13 image file(s) detected.


[1/13] Processing: M20E4-Stitching-09-Create Image Subset-04.czi
