In [36]:
import numpy as np
from cellpose import models, core, io
from spotiflow.model import Spotiflow
from pathlib import Path
from pathlib import Path
import napari
import apoc
from tqdm import tqdm
import pyclesperanto_prototype as cle 
from tifffile import imwrite, imread
from utils import list_images, read_image

# Enable Cellpose logging so that progress messages are printed
io.logger_setup()

# Stop early when no GPU is available: the Cellpose model below is created with gpu=True
if not core.use_gpu():
    raise ImportError("No GPU access, change your runtime")

# Load pre-trained Cellpose-SAM and Spotiflow models
model = models.CellposeModel(gpu=True)
spotiflow_model = Spotiflow.from_pretrained("general")

creating new log file
2025-09-01 13:02:52,166 [INFO] WRITING LOG OUTPUT TO C:\Users\adiez_cmic\.cellpose\run.log
2025-09-01 13:02:52,166 [INFO] 
cellpose version: 	4.0.6 
platform:       	win32 
python version: 	3.10.18 
torch version:  	2.5.0
2025-09-01 13:02:52,173 [INFO] ** TORCH CUDA version installed and working. **
2025-09-01 13:02:52,176 [INFO] ** TORCH CUDA version installed and working. **
2025-09-01 13:02:52,176 [INFO] >>>> using GPU (CUDA)
2025-09-01 13:02:53,554 [INFO] >>>> loading model C:\Users\adiez_cmic\.cellpose\models\cpsam
INFO:spotiflow.model.spotiflow:Loading pretrained model: general
2025-09-01 13:02:54,256 [INFO] Loading pretrained model: general


In [26]:
# Copy the path where your images are stored, you can use absolute or relative paths to point at other disk locations.
# Keep the r"..." prefix: in a normal string the Windows backslashes start escape sequences
# ("\L", "\s", "\P" are invalid escapes and trigger a deprecation/syntax warning; a folder
# starting with "t" or "n" would silently turn into a tab or newline).
directory_path = Path(r"X:\Lisa\siMtb screen I_LØ\Plate 01_Nuc")

# List the image files found in the directory
images = list_images(directory_path)

# Image size reduction (downsampling) to improve processing times (slicing, not lossless compression)
slicing_factor_xy = None # Use 2 or 4 for downsampling in xy (None for lossless)

images

['X:\\Lisa\\siMtb screen I_LØ\\Plate 01_Nuc\\Plate01_Nuc_Wells-A1__Channel_SD_AF647,SD_RFP,SD_GFP,SD_DAPI,SD_BF,SD_NIR.nd2',
 'X:\\Lisa\\siMtb screen I_LØ\\Plate 01_Nuc\\Plate01_Nuc_Wells-A2__Channel_SD_AF647,SD_RFP,SD_GFP,SD_DAPI,SD_BF,SD_NIR.nd2',
 'X:\\Lisa\\siMtb screen I_LØ\\Plate 01_Nuc\\Plate01_Nuc_Wells-A3__Channel_SD_AF647,SD_RFP,SD_GFP,SD_DAPI,SD_BF,SD_NIR.nd2',
 'X:\\Lisa\\siMtb screen I_LØ\\Plate 01_Nuc\\Plate01_Nuc_Wells-A4__Channel_SD_AF647,SD_RFP,SD_GFP,SD_DAPI,SD_BF,SD_NIR.nd2',
 'X:\\Lisa\\siMtb screen I_LØ\\Plate 01_Nuc\\Plate01_Nuc_Wells-A5__Channel_SD_AF647,SD_RFP,SD_GFP,SD_DAPI,SD_BF,SD_NIR.nd2',
 'X:\\Lisa\\siMtb screen I_LØ\\Plate 01_Nuc\\Plate01_Nuc_Wells-A6__Channel_SD_AF647,SD_RFP,SD_GFP,SD_DAPI,SD_BF,SD_NIR.nd2',
 'X:\\Lisa\\siMtb screen I_LØ\\Plate 01_Nuc\\Plate01_Nuc_Wells-A7__Channel_SD_AF647,SD_RFP,SD_GFP,SD_DAPI,SD_BF,SD_NIR.nd2',
 'X:\\Lisa\\siMtb screen I_LØ\\Plate 01_Nuc\\Plate01_Nuc_Wells-A8__Channel_SD_AF647,SD_RFP,SD_GFP,SD_DAPI,SD_BF,SD_NIR.nd2',


In [27]:
# Open a napari viewer in 2D display mode (ndisplay=2); later cells add their layers to it
viewer = napari.Viewer(ndisplay=2)

In [28]:
#TODO: Subtract uneven illumination and remove background from BF by obtaining the median of all BF channels

# Single location for the cached correction image, so the read and the write below cannot drift apart
#TODO: Change to directory_path / "bf_correction.tiff"
bf_correction_path = Path("./raw_data/test_data_2/bf_correction.tiff")

try:

    # Reuse the cached correction image if it has already been computed
    bf_correction = imread(bf_correction_path)

except FileNotFoundError:

    # Create an empty list to store the brightfield images from each well
    bf_arrays = []

    # Read all images and extract the brightfield channel from each of them
    for image in tqdm(images):

        # Read image, apply slicing if needed. A dedicated name is used so the loop
        # does not overwrite the `img` / `filename` variables used by later cells.
        well_img, _ = read_image(image, slicing_factor_xy, log=False)

        # Extract brightfield slice (channel index 4) and add it to the list
        bf_arrays.append(well_img[4])

    # Fail with a clear message instead of the generic np.stack error on an empty list
    if not bf_arrays:
        raise ValueError(
            f"No images found in {directory_path}; cannot compute the brightfield correction"
        ) from None

    # Create a stack containing all bf images
    bf_stack = np.stack(bf_arrays, axis=0)

    # Calculate the per-pixel median to retain the common structures (spots, illumination)
    bf_correction = np.median(bf_stack, axis=0)
    del bf_stack

    # Store brightfield correction as .tiff to avoid recalculating it every time.
    # Create the target folder first so the expensive result is not lost to a missing directory.
    bf_correction_path.parent.mkdir(parents=True, exist_ok=True)
    imwrite(bf_correction_path, bf_correction)

viewer.add_image(bf_correction)

<Image layer 'bf_correction' at 0x1c895a99990>

In [None]:
# Explore a different image by changing the index (0 is the first image in the directory)
image = images[1]

# Read image, apply slicing if needed and return filename and img as a np array
img, filename = read_image(image, slicing_factor_xy)

# Extract plate number and well_id from names such as "Plate01_Nuc_Wells-A2__Channel_...".
# The well id is the text between the first "-" and the next "_", so wells with a
# two-digit column (e.g. "A10") are kept whole instead of being cut to two characters ("A1").
plate_nr = filename.split("_")[0]
well_id = filename.split("-")[1].split("_")[0]

viewer.add_image(img)

Compressed Array shape: (6, 5032, 5032)


Image analyzed: Plate01_Nuc_Wells-A2__Channel_SD_AF647,SD_RFP,SD_GFP,SD_DAPI,SD_BF,SD_NIR
Original Array shape: (6, 5032, 5032)


<Image layer 'img' at 0x1c7f5b4e7d0>

In [30]:
# Segment nuclei on the last channel only; the [-1:] slice keeps the leading axis (length 1)
# TODO: the eval arguments (e.g. niter) still need to be checked
nuclei_channel = img[-1:]
nuclei_labels, flows, styles = model.eval(nuclei_channel, niter=1000)
viewer.add_labels(nuclei_labels)

<Labels layer 'nuclei_labels' at 0x1c88dd031f0>

In [31]:
# Segment cytoplasm on a two-plane input stacked along axis 0
# TODO: the eval arguments (e.g. niter) still need to be checked
cytoplasm_labels, flows, styles = model.eval(
    np.stack(
        (
            # Plane 1: channels 0 and 1 summed (presumably AF647 + RFP given the
            # channel order in the file name -- confirm)
            img[[0, 1]].sum(axis=0),
            # Plane 2: channel 4 with the brightfield correction image subtracted
            (img[4] - bf_correction),
        ),
        axis=0,
    ),
    niter=1000,
)
viewer.add_labels(cytoplasm_labels)

<Labels layer 'cytoplasm_labels' at 0x1c891e22bc0>

In [32]:
# Run the pretrained Spotiflow model on channel 0 (subpix=True is passed through to predict)
points, details = spotiflow_model.predict(img[0], subpix=True)
# Show the heatmap returned in `details` and overlay the predicted points in red
viewer.add_image(details.heatmap, colormap="viridis")
viewer.add_points(points, face_color='red')

INFO:spotiflow.model.spotiflow:Will use device: cuda:0
2025-09-01 12:50:40,739 [INFO] Will use device: cuda:0
INFO:spotiflow.model.spotiflow:Predicting with prob_thresh = [0.6], min_distance = 1
2025-09-01 12:50:40,739 [INFO] Predicting with prob_thresh = [0.6], min_distance = 1
INFO:spotiflow.model.spotiflow:Peak detection mode: fast
2025-09-01 12:50:40,739 [INFO] Peak detection mode: fast
INFO:spotiflow.model.spotiflow:Image shape (5032, 5032)
2025-09-01 12:50:40,739 [INFO] Image shape (5032, 5032)
INFO:spotiflow.model.spotiflow:Predicting with (3, 3) tiles
2025-09-01 12:50:40,739 [INFO] Predicting with (3, 3) tiles
INFO:spotiflow.model.spotiflow:Normalizing...
2025-09-01 12:50:40,774 [INFO] Normalizing...
INFO:spotiflow.model.spotiflow:Padding to shape (5040, 5040, 1)
2025-09-01 12:50:41,000 [INFO] Padding to shape (5040, 5040, 1)


Predicting tiles: 100%|██████████| 9/9 [00:03<00:00,  3.00it/s]

INFO:spotiflow.model.spotiflow:Found 3341 spots
2025-09-01 12:50:44,042 [INFO] Found 3341 spots





<Points layer 'points' at 0x1c7e204dbd0>

In [37]:
#TODO: Bacterial detection, infection rate calculation

# Voronoi-Otsu labeling is strongly affected by outliers, so a trained APOC object segmenter is used instead
# Location of the classifier file handed to the segmenter
cl_filename = "./raw_data/SD_DAPI_Mtb_detection_training/Mtb_segmenter.cl"
mtb_segmenter = apoc.ObjectSegmenter(opencl_filename=cl_filename)

# Run the segmenter on channel 3 and hand its result straight to cle.pull,
# so mtb_labels is only ever bound to the pulled array
mtb_labels = cle.pull(mtb_segmenter.predict(img[3]))

In [38]:
# Show the bacterial segmentation as a labels layer in the viewer
viewer.add_labels(mtb_labels)

<Labels layer 'mtb_labels [1]' at 0x1c88af308b0>

In [None]:
# Convert mtb_labels to boolean mask
mtb_boolean = mtb_labels.astype(bool)

# Find cytoplasm labels that intersect with the mtb signal
# This will only be necessary if we decide to perform an erosion operation later to get rid of partially touching bacteria in the edge of cells
# NOTE: `cytoplasm_labels & mtb_boolean` must not be used here. With an integer label image the
# boolean mask is treated as 0/1 and the bitwise AND returns only the lowest bit of each label
# (1 for odd ids, 0 for even ids) instead of the label id itself.
# np.where keeps the label id under the bacterial mask and writes 0 everywhere else.
cytoplasm_and_mtb = np.where(mtb_boolean, cytoplasm_labels, 0).astype(cytoplasm_labels.dtype)

In [40]:
# Visual check: show the bacterial mask and the cytoplasm/bacteria overlap as image layers
viewer.add_image(mtb_boolean)
viewer.add_image(cytoplasm_and_mtb)

<Image layer 'cytoplasm_and_mtb' at 0x1c814e25e40>

In [43]:
# Boolean indexing picks the cytoplasm label id under every pixel of the bacterial mask;
# label 0 is dropped and each remaining id is kept once
labels_under_mtb = cytoplasm_labels[mtb_boolean]
infected_labels = np.unique(labels_under_mtb[labels_under_mtb != 0])

In [44]:
# Display the ids of the cytoplasm labels that overlap the bacterial mask
infected_labels

array([   3,    6,   10,   17,   23,   25,   27,   31,   34,   35,   41,
         43,   44,   45,   56,   62,   64,   67,   81,   83,   92,   93,
         97,  100,  109,  110,  114,  120,  121,  123,  126,  130,  132,
        135,  136,  140,  148,  151,  152,  156,  165,  174,  180,  182,
        183,  190,  192,  198,  203,  204,  205,  208,  215,  218,  223,
        225,  227,  231,  235,  237,  251,  256,  262,  268,  287,  289,
        291,  292,  300,  306,  313,  317,  322,  324,  336,  344,  347,
        350,  371,  379,  381,  383,  385,  386,  395,  396,  398,  402,
        409,  413,  419,  422,  425,  437,  438,  447,  448,  450,  455,
        458,  460,  461,  462,  464,  465,  466,  468,  480,  482,  487,
        490,  504,  507,  509,  512,  513,  516,  517,  518,  524,  538,
        540,  541,  545,  547,  553,  554,  560,  563,  572,  573,  577,
        581,  586,  596,  598,  609,  621,  623,  630,  631,  654,  670,
        681,  684,  686,  687,  692,  702,  704,  7

In [48]:
# True for every pixel whose cytoplasm label is listed in infected_labels
infected_mask = np.isin(cytoplasm_labels, infected_labels)
# Element-wise complement of the mask above (label-0 pixels are therefore True here as well)
non_infected_mask = ~infected_mask

# Label images restricted to each group; pixels outside the group are set to 0
label_dtype = cytoplasm_labels.dtype
infected_cytoplasm = np.where(infected_mask, cytoplasm_labels, 0).astype(label_dtype)
non_infected_cytoplasm = np.where(non_infected_mask, cytoplasm_labels, 0).astype(label_dtype)

In [49]:
# Add the infected-cell mask and the infected / non-infected label images to the viewer for inspection
viewer.add_labels(infected_mask)
viewer.add_labels(infected_cytoplasm)
viewer.add_labels(non_infected_cytoplasm)

<Labels layer 'non_infected_cytoplasm' at 0x1c7f71fba30>

In [71]:
# Count cells as the number of distinct non-zero label ids in each label image
# (np.unique lists every id once, count_nonzero leaves out label 0).
infected_cells = int(np.count_nonzero(np.unique(infected_cytoplasm)))
non_infected_cells = int(np.count_nonzero(np.unique(non_infected_cytoplasm)))

# The highest label id only equals the number of cells while ids are contiguous (1..N);
# counting distinct ids stays correct if labels are ever removed or filtered.
total_cells = int(np.count_nonzero(np.unique(cytoplasm_labels)))

# Every cell is either infected or non-infected; a mismatch points at stale variables
# from cells executed out of order.
assert infected_cells + non_infected_cells == total_cells, (
    "Infected and non-infected counts do not add up to the total number of cells"
)


In [72]:
# Report the three counts, one per line
print(
    f"Non-infected: {non_infected_cells}",
    f"Infected: {infected_cells}",
    f"Total cells: {total_cells}",
    sep="\n",
)

Non-infected: 2217
Infected: 814
Total cells: 3031


In [73]:
# Percentage of infected cells, rounded to two decimals; 0 when no cells were found
if total_cells > 0:
    perc_inf_cells = round(infected_cells / total_cells * 100, 2)
else:
    perc_inf_cells = 0
print(perc_inf_cells)

26.86
