In [None]:
import napari
import czifile
import numpy as np
import matplotlib.pyplot as plt
import torch
from pathlib import Path
from monai.inferers import sliding_window_inference
from tnia.deeplearning.dl_helper import quantile_normalization
import torch.nn.functional as F
import tifffile
from tqdm import tqdm


In [None]:
# Project root; the CZI stacks to segment are looked up in its "data" subfolder
parent_path = Path(r'C:/Users/Alex/Desktop/Mailis')
czi_dir = parent_path / "data"
# Sorted so the files are always listed (and later processed) in the same order
input_files = sorted(czi_dir.glob("*.czi"))

print(f"Found {len(input_files)} files:")
for czi_file in input_files:
    print(czi_file.name)


In [None]:
# Load model
models_path = parent_path / 'models'
# Choose the device first so the checkpoint can be mapped onto it while loading
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# put the name of your model.pth
# NOTE: weights_only=False unpickles the complete model object, which can execute
# arbitrary code embedded in the file - only load checkpoints from a trusted source.
# map_location makes a checkpoint that was saved on a GPU loadable on a CPU-only machine.
net = torch.load(models_path / 'full_brain.pth', map_location=device, weights_only=False)
# Inference only: eval() switches dropout / batch-norm layers to evaluation behaviour
net = net.to(device).eval()


In [None]:
# Define the prediction function that returns class probability maps (multi-channel)
def predict_multichannel(im, net, roi_size=1024, sw_batch_size=1):
    """Run sliding-window inference on one image and return softmax probabilities.

    Parameters
    ----------
    im : np.ndarray
        Image to segment. The processing loop in this notebook passes single
        2D planes; batch and channel axes are added here.
    net : callable
        Predictor handed to ``sliding_window_inference``. Its train/eval mode
        is not modified by this function.
    roi_size : int or sequence of int, optional
        Window size forwarded to ``sliding_window_inference`` (default 1024,
        the value that was previously hard-coded).
    sw_batch_size : int, optional
        Number of windows evaluated per forward pass (default 1, as before).

    Returns
    -------
    np.ndarray
        Softmax over axis 1 of the network output with the batch axis removed;
        (C, H, W) for a 2D input, assuming the network emits one channel per class.
    """
    # quantile_normalization is provided by tnia; presumably a percentile-based
    # intensity rescale - confirm against the tnia documentation.
    im = quantile_normalization(im)
    im = im.astype(np.float32)
    # Add batch and channel axes, then move to the notebook-level `device`
    im_tensor = torch.from_numpy(im).unsqueeze(0).unsqueeze(0).to(device)

    # No gradients are needed for inference; this also keeps memory use down
    with torch.no_grad():
        logits = sliding_window_inference(im_tensor, roi_size, sw_batch_size, net)
        probabilities = F.softmax(logits, dim=1)

    prob_np = probabilities.squeeze(0).cpu().numpy()  # Shape: (C, H, W)
    return prob_np


In [None]:
# Process each CZI file and save multi-channel predictions as OME-TIFF
output_dir = parent_path / "Brain_volume"
output_dir.mkdir(parents=True, exist_ok=True)

for czi_path in input_files:
    print(f"Processing: {czi_path.name}")
    image = czifile.imread(czi_path)
    # Drop singleton axes; everything below indexes the result as (Z, Y, X)
    image = np.squeeze(image)
    if image.ndim == 2:
        # Single-plane file: squeeze removed the Z axis as well, so put it back
        image = image[np.newaxis, ...]
    elif image.ndim != 3:
        # e.g. multi-channel or time-series data would otherwise be sliced incorrectly
        raise ValueError(
            f"{czi_path.name}: expected a single-channel (Z, Y, X) stack after squeeze, "
            f"got shape {image.shape}"
        )
    save_path = output_dir / (czi_path.stem + "_prediction_multichannel.ome.tif")

    # Predict the first slice once to learn the number of classes and the plane size
    sample_prediction = predict_multichannel(image[0, :, :], net)
    num_classes = sample_prediction.shape[0]
    depth = image.shape[0]
    height, width = sample_prediction.shape[1:]

    # Initialize an empty array for all slices and classes: (C, Z, Y, X)
    all_predictions = np.zeros((num_classes, depth, height, width), dtype=np.uint8)

    for i in tqdm(range(depth), desc=f"Predicting {czi_path.name}"):
        image2d = image[i, :, :]
        # Slice 0 was already predicted above; reuse it instead of running the network twice
        prediction = sample_prediction if i == 0 else predict_multichannel(image2d, net)
        # Scale probabilities to 0–255; astype truncates the fractional part
        prediction = (prediction * 255).astype(np.uint8)
        all_predictions[:, i, :, :] = prediction

    # Save as OME-TIFF with shape (C, Z, Y, X)
    tifffile.imwrite(
        save_path,
        all_predictions,
        photometric='minisblack',
        metadata={'axes': 'CZYX'},
        bigtiff=True
    )

    print(f"Saved multi-channel OME-TIFF prediction to {save_path}")
