In [None]:
## IMPORT LIBRARIES ##

import os
import numpy as np
import pandas as pd
import json
import torch
from torchvision import transforms
from tifffile import imwrite
from aicsimageio import AICSImage
from cellpose import models
from skimage import img_as_float, exposure
from skimage.io import imread, imsave
from skimage.measure import regionprops
from skimage.util import img_as_ubyte
from skimage.transform import resize
from torch.utils.data import Dataset, DataLoader
from skimage.color import label2rgb
from skimage.segmentation import find_boundaries
from scipy import ndimage
from scipy.ndimage import median_filter

if torch.cuda.is_available() is True:
    try:
        import cupy as cu
        from cucim.skimage.morphology import dilation, disk
        cuda = True
    except ImportError:
        from skimage.morphology import dilation, disk
        cuda = False
else:
    from skimage.morphology import dilation, disk
    cuda = False

In [None]:
## ENTER YOUR VALUES ##

# --- Locations ---------------------------------------------------------------
parent_directory = "/path/to/your/folder/"  # Root folder of the experiment - replace with the actual path
cellpose_directory = "/path/to/your/folder/"  # Handed to Cellpose as `pretrained_model` - replace with the actual path
model_name = "MyoFuse.pth"  # File name of your classifier (looked up in the "Svetlana" sub-folder)

# --- Pre-processing ----------------------------------------------------------
myotube_channel = 1  # Index of the myotube channel in your images
nuclei_channel = 0  # Index of the nuclei channel in your images
extension = (".tif", ".tiff")

# --- Segmentation ------------------------------------------------------------
# Adjust this value depending on your images.
# Set it to 0 for Cellpose to automatically determine the best value.
dia = 24

# --- Classification ----------------------------------------------------------
half_patch_size = 100  # Patches are cut from -half_patch_size to +half_patch_size around each centroid
batch_size = 512

# --- Prediction images -------------------------------------------------------
save_prediction = True  # False disables saving of the prediction images
downsampling_factor = 1  # 1 keeps the prediction image at full resolution

# --- Derived paths -----------------------------------------------------------
full_images_folder = os.path.join(parent_directory, "Full Images")
myotube_folder = os.path.join(parent_directory, "Myotubes")
nuclei_folder = os.path.join(parent_directory, "Nuclei")
masks_folder = os.path.join(parent_directory, "Masks")
norm_folder = os.path.join(parent_directory, "Images")
predictions_output_dir = os.path.join(parent_directory, "Predictions")

# Classifier checkpoint and its configuration both live in "Svetlana".
svetlana_folder = os.path.join(parent_directory, "Svetlana")
model_path = os.path.join(svetlana_folder, model_name)
config_path = os.path.join(svetlana_folder, "Config.json")

# Create every output folder up front (no-op when it already exists).
for _output_folder in (
    myotube_folder,
    nuclei_folder,
    masks_folder,
    norm_folder,
    predictions_output_dir,
):
    os.makedirs(_output_folder, exist_ok=True)

In [None]:
## IMAGE SPLITTING ## 
'''
IF NECESSARY - Split images into separate channels for processing
'''
# Every supported file is brought to a channels-first (C, Y, X) array, which is
# the layout the channel extraction below indexes (image[channel, :, :]).
# NOTE(review): TIFF files are assumed to be stored channels-first already;
# confirm for files written channels-last (Y, X, C).
if os.path.exists(full_images_folder):
    # Highest channel index requested in the configuration cell.
    highest_channel = max(myotube_channel, nuclei_channel)

    for file in os.listdir(full_images_folder):
        file_path = os.path.join(full_images_folder, file)
        lowered_name = file.lower()

        # ".ome.tiff" must be tested BEFORE ".tiff": every ".ome.tiff" name
        # also ends with ".tiff", so the opposite order never reaches this branch.
        if lowered_name.endswith((".ome.tiff", ".czi")):
            # Disable mosaic reconstruction to avoid the broadcasting problem
            aics_image = AICSImage(file_path, reconstruct_mosaic=False)
            image = aics_image.data[0]  # Data of the first scene
            print(image.shape)
            if image.ndim == 5:  # TCZYX
                # First time point, first Z plane, all channels -> (C, Y, X)
                image = image[0, :, 0]
            elif image.ndim == 4:  # CZYX
                # First Z plane, all channels -> (C, Y, X)
                image = image[:, 0]
            else:
                print(f"Skipped {file}: unsupported file format.")
                continue

        elif lowered_name.endswith((".tiff", ".tif")):
            image = imread(file_path)

        else:
            # Not an image: skip it instead of re-using the previous
            # iteration's `image` (or failing when none was loaded yet).
            continue

        # Check that both requested channels exist along the channel axis
        if image.ndim == 3 and image.shape[0] > highest_channel:
            myotube = image[myotube_channel, :, :]
            nuclei = image[nuclei_channel, :, :]

            # Save Myotube channel
            myotube_save_path = os.path.join(myotube_folder, f"{os.path.splitext(file)[0]}.tif")
            imsave(myotube_save_path, myotube.astype(np.uint16))

            # Save Nuclei channel
            nuclei_save_path = os.path.join(nuclei_folder, f"{os.path.splitext(file)[0]}.tif")
            imsave(nuclei_save_path, nuclei.astype(np.uint16))

            print(f"Processed and saved channels for: {file}")
        else:
            print(f"Skipped {file}: not enough channels.")
else:
    print(f"The folder 'Full Images' does not exist in {parent_directory}")

In [None]:
## SEGMENTATION ##

def load_image(file_path):
    """Load a single-channel image.

    Opens `file_path` with AICSImage and returns `data[0][0]`, i.e. the array
    that remains after taking the first element of the two leading axes.
    """
    return AICSImage(file_path).data[0][0]


def loop(file_path, parent_directory, diameter, model=None):
    """Segment one nuclei image with Cellpose and save its label mask.

    Parameters
    ----------
    file_path : str
        Path of the nuclei image to segment.
    parent_directory : str
        Unused; kept so existing calls keep working (the mask is written to
        the module-level `masks_folder`).
    diameter : int or float
        Object diameter forwarded to Cellpose. A falsy value (0 / None) is
        forwarded as None so that Cellpose picks the diameter itself.
    model : optional
        An already constructed `models.CellposeModel` to reuse. When omitted a
        model is built from `cellpose_directory` for this call only.

    Returns
    -------
    dict
        Keys 'image', 'diameter' and 'cells count' (the largest label value).
    """
    name_without_ext = os.path.splitext(os.path.basename(file_path))[0]

    # Load image
    img = load_image(file_path)

    # Cellpose
    if model is None:
        model = models.CellposeModel(
            gpu=False,
            pretrained_model=cellpose_directory)

    # `diameter` has to be passed by keyword: the second positional parameter
    # of CellposeModel.eval is `batch_size`, so a positional value would
    # silently change the batch size and leave the diameter at its default.
    masks, flows, styles = model.eval(
        img,
        diameter=diameter if diameter else None,
        channels=[0, 0],
        normalize=True,
    )

    # Save mask as .tiff (only when at least one object was found)
    if masks.max() > 0:
        mask_save_path = os.path.join(masks_folder, name_without_ext + ".tif")
        imwrite(mask_save_path, masks.astype(np.uint32), compression='zlib')
        print(f"Mask saved: {mask_save_path}")
    else:
        print(f"No cells detected in {file_path}, no mask has been saved.")

    result = {
        'image': file_path,
        'diameter': diameter,
        'cells count': np.max(masks)
    }

    # Empty GPU cache
    torch.cuda.empty_cache()

    return result


# List all image files in the "Nuclei" folder
if os.path.exists(nuclei_folder):
    image_files = sorted(
        f for f in os.listdir(nuclei_folder) if f.lower().endswith(('.tif', '.tiff'))
    )
else:
    print(f"The folder 'Nuclei' does not exist in {parent_directory}")
    image_files = []

# Print the list of images
for file in image_files:
    print(f"{os.path.join(nuclei_folder, file)}")

# Process each image, building the Cellpose model once and reusing it
if image_files:
    cellpose_model = models.CellposeModel(
        gpu=False,
        pretrained_model=cellpose_directory)

    for file in image_files:
        file_path = os.path.join(nuclei_folder, file)
        print(f"Processing file: {file_path}")
        result = loop(file_path, parent_directory, dia, model=cellpose_model)

In [None]:
## CLASSIFICATION PRE-PROCESSING ##

## Normalize image

def normalize_image_first_percentile(image, percentile_min, percentile_max):
    """Rescale an image to [0, 1] between two percentiles of its non-zero pixels.

    Pixels below the lower percentile (and pixels that are zero) map to 0,
    pixels at or above the upper percentile map to 1. The input array is left
    untouched.

    Parameters
    ----------
    image : numpy.ndarray
        Image to normalize. Non floating-point inputs are processed as float32
        so that the fractional result is not truncated.
    percentile_min, percentile_max : float
        Percentiles (0-100), computed over the strictly positive pixels, that
        define the new black and white points.

    Returns
    -------
    numpy.ndarray
        Normalized image with values in [0, 1]. When the image contains no
        positive pixel it is returned unchanged.
    """
    image = np.asarray(image)

    # Do the arithmetic in floating point; an integer output buffer would
    # truncate every normalized value to 0 or 1.
    if np.issubdtype(image.dtype, np.floating):
        work = image
    else:
        work = image.astype(np.float32)

    # Extract non-zero pixels
    nonzero_pixels = work[work > 0]
    if nonzero_pixels.size == 0:
        return image  # Nothing to normalize if the image only contains zeros

    # Black and white points among the non-zero pixels
    new_min = np.percentile(nonzero_pixels, percentile_min)
    max_val = np.percentile(nonzero_pixels, percentile_max)

    # Output starts at 0 so that everything below new_min stays 0. The input
    # itself is never written to.
    norm_image = np.zeros_like(work)
    mask = work >= new_min

    span = max_val - new_min
    if span != 0:
        norm_image[mask] = (work[mask] - new_min) / span
    else:
        # Both percentiles coincide (e.g. constant foreground): dividing would
        # yield NaN/inf, so the retained pixels are saturated to 1 instead.
        norm_image[mask] = 1

    # Clip to ensure values stay within [0, 1]
    return np.clip(norm_image, 0, 1)

def save_labels_images(
    labels, 
    myotube_image, 
    output_path,
    percentile=0
):
    """Normalize a myotube image and write it to `output_path`.

    A 3-channel image (channels on the last axis) is first averaged into a
    single plane. The plane is converted to float32, normalized between the
    4th and 99.7th percentile of its non-zero pixels and saved.

    `labels` and `percentile` are accepted but not used by this function.
    """
    lower_pct, upper_pct = 4, 99.7

    # Collapse a 3-channel image into one grey plane
    has_three_channels = myotube_image.ndim == 3 and myotube_image.shape[2] == 3
    grey = np.mean(myotube_image, axis=2) if has_three_channels else myotube_image

    # float32 keeps the fractional values produced by the normalization
    grey = grey.astype(np.float32)
    normalized = normalize_image_first_percentile(grey, lower_pct, upper_pct)

    # check_contrast=False silences the low-contrast check of imsave
    imsave(output_path, normalized, check_contrast=False)

    print(f"Processed {output_path}")

def process_folder(myotube_folder, labels_folder, output_folder):
    """Normalize every TIFF of `myotube_folder` that has a same-named label file.

    Files are handled in sorted name order. A file without a counterpart in
    `labels_folder` is reported and skipped; the others are read together with
    their label image and passed to `save_labels_images`, which writes the
    result under the same name in `output_folder` (created if missing).
    """
    os.makedirs(output_folder, exist_ok=True)

    tiff_names = [
        name for name in os.listdir(myotube_folder)
        if name.lower().endswith(('.tif', '.tiff'))
    ]
    tiff_names.sort()

    for name in tiff_names:
        label_path = os.path.join(labels_folder, name)

        if os.path.exists(label_path):
            myotube_image = imread(os.path.join(myotube_folder, name))
            labels = imread(label_path)
            save_labels_images(labels, myotube_image, os.path.join(output_folder, name))
        else:
            print(f"Label missing for {name}, skipping.")
        
# Normalize each myotube image that has a same-named mask in `masks_folder`
# and write the result to `norm_folder`.
process_folder(myotube_folder, masks_folder, norm_folder)

In [None]:
## CLASSIFICATION ##

## Collect label centroids

def max_to_one(im):
    """Scale an array so that its maximum becomes 1.

    Parameters
    ----------
    im : numpy.ndarray
        Array to scale.

    Returns
    -------
    numpy.ndarray
        `im` divided by its maximum. When the maximum is 0 (e.g. an all-zero
        patch) the values are returned unscaled, because dividing by 0 would
        fill the result with NaN.
    """
    peak = im.max()
    if peak == 0:
        # Nothing to scale; avoid 0/0 -> NaN
        peak = 1

    return im / peak

def compute_label_centroids(label_image):
    """Return the non-zero label values of a label image and their centroids.

    The image is viewed as int32, its distinct values are collected, 0 is
    dropped, and the centre of mass of each remaining label is computed with
    uniform pixel weights.

    Returns
    -------
    tuple
        (label values, centroids) where centroids is the output of
        `ndimage.center_of_mass` for those labels, in the same order.
    """
    as_int = label_image.astype(np.int32, copy=False)

    # Distinct labels without the background value 0
    object_ids = np.unique(as_int)
    object_ids = object_ids[np.nonzero(object_ids)]

    # Uniform weights: every pixel of an object counts the same
    uniform = np.ones(as_int.shape, dtype=np.float32)
    centers = ndimage.center_of_mass(uniform, labels=as_int, index=object_ids)

    return object_ids, centers

## Load images, masks

def load_images_labels_centroids(parent_directory):
    """Load every normalized image that has a mask, plus its labels and centroids.

    TIFF files of the module-level `norm_folder` are visited in sorted order;
    a file without a same-named mask in `masks_folder` is reported and skipped.

    Parameters
    ----------
    parent_directory : str
        Unused; kept so existing calls keep working.

    Returns
    -------
    tuple of five lists, all aligned index by index:
        file names that were actually loaded, images (float32, a 2-D image is
        replicated into 3 channels on the last axis), mask arrays, label ids
        and centroids as returned by `compute_label_centroids`.
    """
    candidates = sorted([f for f in os.listdir(norm_folder) if f.lower().endswith(('.tif', '.tiff'))])

    # `loaded_files` only receives the names that are really loaded, so that
    # position i refers to the same image in every returned list even when
    # some files are skipped.
    loaded_files, images, labels, all_label_ids, all_centroids = [], [], [], [], []

    for img_file in candidates:
        img_path = os.path.join(norm_folder, img_file)
        mask_path = os.path.join(masks_folder, img_file)

        if not os.path.exists(mask_path):
            print(f"No corresponding mask for {img_file}, skipping.")
            continue

        # Load Image
        image_data = imread(img_path).astype(np.float32)

        if image_data.ndim == 2:
            # Duplicate the plane into 3 channels
            image_data = np.stack([image_data]*3, axis=-1)

        # Load Mask
        mask_data = imread(mask_path)

        # Calculate centroids
        label_ids, centroids = compute_label_centroids(mask_data)

        # Store
        loaded_files.append(img_file)
        images.append(image_data)
        labels.append(mask_data)
        all_label_ids.append(label_ids)
        all_centroids.append(centroids)

    return loaded_files, images, labels, all_label_ids, all_centroids

## Prediction functions

class PredictionDataset(Dataset):
    """
    Dataset yielding one 4-channel patch per label of a single image.

    Each item is a square patch of side 2 * half_patch_size cut around the
    label's centroid: channels 0-2 hold the image patch scaled by
    `max_to_one`, channel 3 holds the binary mask of that label.

      prop.centroid = (row, col)  =>  cx = row, cy = col  

    Image and label arrays are zero-padded by half_patch_size + 1 on both
    spatial axes so that patches around border centroids stay inside the array.
    """

    def __init__(self, image, labels, label_ids, centroids, half_patch_size, device, config_dict):
        super().__init__()
        self.image = image
        self.labels = labels
        self.label_ids = label_ids
        self.centroids = centroids
        self.half_patch_size = half_patch_size
        self.device = device
        self.config_dict = config_dict

        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

        # Dilation options are read once here instead of once per item.
        # "dilate_mask" is stored as text, hence the lower() + json.loads.
        dilation_cfg = self.config_dict["options"]["dilation"]
        self.do_dilate = json.loads(dilation_cfg["dilate_mask"].lower())
        self.strel = disk(int(dilation_cfg["str_element_size"])) if self.do_dilate else None

        self.pad = half_patch_size + 1
        pad = self.pad

        self.image = np.pad(
            self.image,
            ((pad, pad), (pad, pad), (0, 0)),
            mode="constant"
        )
        if self.labels.dtype != np.int32:
            self.labels = self.labels.astype(np.int32, copy=False)
        self.labels = np.pad(
            self.labels,
            ((pad, pad), (pad, pad)),
            mode="constant"
        )

    def __len__(self):
        return len(self.label_ids)

    def _build_patch(self, idx):
        """Return the (H, W, 4) float32 patch of label number `idx`."""
        row, col = self.centroids[idx]
        # Shift the centroid into the padded coordinate system
        cx = int(row) + self.pad
        cy = int(col) + self.pad
        hps = self.half_patch_size
        xmin, xmax = cx - hps, cx + hps
        ymin, ymax = cy - hps, cy + hps

        # Extract patch
        patch_img = self.image[xmin:xmax, ymin:ymax, :].copy()
        patch_mask = self.labels[xmin:xmax, ymin:ymax].copy()

        label_id = self.label_ids[idx]

        # Normalize
        patch_img = max_to_one(patch_img)

        # Binarize: keep only the pixels of the current label
        patch_mask[patch_mask != label_id] = 0
        patch_mask[patch_mask == label_id] = 1

        # Dilation if stated in the config file
        if self.do_dilate:
            if cuda:
                # `dilation`/`disk` come from cuCIM on this path, which works
                # on CuPy arrays: move the mask to the GPU and back.
                # NOTE(review): relies on the module-level `cuda`/`cu` names
                # set by the import cell.
                patch_mask = cu.asnumpy(dilation(cu.asarray(patch_mask), self.strel))
            else:
                patch_mask = dilation(patch_mask, self.strel)
            patch_img *= patch_mask[..., None]

        # Mask concatenation as 4th channel
        out = np.zeros((patch_img.shape[0], patch_img.shape[1], 4), dtype=np.float32)
        out[..., :3] = patch_img
        out[..., 3] = patch_mask
        return out

    def __getitem__(self, idx):
        # A failing patch must not be dropped silently: predictions are later
        # paired with label ids by position, and a None item cannot be batched
        # by the default collate function anyway. Re-raise with the index.
        try:
            out = self._build_patch(idx)
        except Exception as e:
            raise RuntimeError(f"[Dataset] Erreur index {idx} : {e}") from e

        # Transformation to (C, H, W) + Tensor
        return self.transform(out).to(self.device)

## Save prediction image

def save_colored_predictions_downsample(
    labels, 
    predictions, 
    used_labels, 
    myotube_image, 
    output_path, 
    factor=1
):
    """Save an RGB overview where each label outline is coloured by its class.

    The myotube image (a 3-channel image is averaged to one plane) is resized
    to (H // factor, W // factor), contrast-stretched between its 2nd and 98th
    percentile and used as grey background. The inner boundary pixels of every
    label are then painted red when the label's prediction equals 0 and green
    otherwise. Labels missing from `used_labels` are left unpainted.

    Parameters
    ----------
    labels : numpy.ndarray
        Label image at full resolution.
    predictions : sequence
        Predicted class of each label, aligned with `used_labels`.
    used_labels : sequence
        Label values the predictions refer to.
    myotube_image : numpy.ndarray
        2-D image, or 3-D image with 3 channels on the last axis.
    output_path : str
        Destination file.
    factor : int
        Integer downsampling factor (1 = full resolution).

    Raises
    ------
    ValueError
        If predictions and labels differ in length, or `factor` is not a
        positive integer.
    """
    if len(predictions) != len(used_labels):
        raise ValueError("Nombre de prédictions != nombre de labels.")
    if factor < 1 or int(factor) != factor:
        raise ValueError(f"factor must be a positive integer, got {factor!r}")
    factor = int(factor)

    if myotube_image.ndim == 3 and myotube_image.shape[2] == 3:
        myotube_image = np.mean(myotube_image, axis=2)
    H, W = myotube_image.shape

    newH, newW = H//factor, W//factor
    myotube_ds = resize(myotube_image, (newH, newW),
                        preserve_range=True,
                        anti_aliasing=True)

    # Removing lowest and highest values
    p2, p98 = np.percentile(myotube_ds, (2, 98))

    myotube_ds = exposure.rescale_intensity(myotube_ds, in_range=(p2, p98), out_range=(0,1))

    myotube_rgb = np.stack([myotube_ds]*3, axis=-1)

    color_0 = [1.0, 0.0, 0.0]  # Red = Nuclei Out
    color_1 = [0.0, 1.0, 0.0]  # Green = Nuclei In

    # Boundary pixels that belong to a label, in row-major order
    boundaries = find_boundaries(labels, mode='inner')
    rows, cols = np.nonzero(boundaries & (labels != 0))

    used = np.asarray(used_labels)
    if rows.size and used.size:
        # Vectorized label -> prediction lookup through a sorted copy of
        # `used_labels` (replaces a per-pixel Python dictionary lookup).
        order = np.argsort(used, kind="stable")
        used_sorted = used[order]
        preds_sorted = np.asarray(predictions)[order]

        pixel_labels = labels[rows, cols]
        pos = np.searchsorted(used_sorted, pixel_labels, side="right") - 1
        pos = np.maximum(pos, 0)
        known = used_sorted[pos] == pixel_labels

        # Position of each boundary pixel in the downsampled image
        yd = rows // factor
        xd = cols // factor
        keep = known & (yd < newH) & (xd < newW)
        yd, xd = yd[keep], xd[keep]
        is_out = preds_sorted[pos[keep]] == 0

        # Several source pixels can fall on the same target pixel when
        # factor > 1. Keep the last one in row-major order for each target so
        # the result matches sequential painting.
        flat = yd * newW + xd
        _, first_in_reversed = np.unique(flat[::-1], return_index=True)
        last = flat.size - 1 - first_in_reversed

        myotube_rgb[yd[last], xd[last]] = np.where(is_out[last][:, None], color_0, color_1)

    myotube_rgb_8 = img_as_ubyte(myotube_rgb)
    imsave(output_path, myotube_rgb_8)

## Main loop for prediction

def main():
    """Classify every nucleus of every image and export the fusion index.

    Steps: read the JSON config, load the classifier checkpoint, load images /
    masks / centroids, predict a class for each label in batches, optionally
    save a coloured prediction image per input image, then write one summary
    row per image to "Fusion_index.xlsx" in `parent_directory`.

    Returns
    -------
    tuple
        (all_predictions, predictions_df): the per-image lists of predicted
        classes and the summary DataFrame.

    Raises
    ------
    RuntimeError
        If the loader returns lists of different lengths, which would pair
        images with the wrong file names.
    """
    # Load config file
    with open(config_path, 'r') as f:
        config_dict = json.load(f)

    # Load classifier
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print ("Classification : ", device)

    # SECURITY: the checkpoint stores a whole pickled model object, so loading
    # it executes pickle code - only load checkpoint files you trust.
    # `weights_only=False` is required where torch.load defaults to
    # weights-only loading; older releases do not know the keyword, hence the
    # fallback.
    try:
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    except TypeError:
        checkpoint = torch.load(model_path, map_location=device)
    model = checkpoint["model"].to(device)
    model.eval()

    # Load Images + Masks + Centroids
    image_files, images, labels, all_label_ids, all_centroids = load_images_labels_centroids(parent_directory)

    # Everything below pairs the lists by position; refuse to continue if they
    # are not aligned rather than reporting results under the wrong image name.
    if not (len(image_files) == len(images) == len(labels) == len(all_label_ids) == len(all_centroids)):
        raise RuntimeError(
            "Loaded file names, images and masks are not aligned "
            f"({len(image_files)} names for {len(images)} images)."
        )

    all_predictions = []

    for i, img_file in enumerate(image_files):
        current_image = images[i]
        current_label = labels[i]
        label_ids = all_label_ids[i]
        centroids = all_centroids[i]  # list of (row, col)

        dataset = PredictionDataset(
            image=current_image,
            labels=current_label,
            label_ids=label_ids,
            centroids=centroids,
            half_patch_size=half_patch_size,
            device=device,
            config_dict=config_dict
        )
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

        image_preds = []
        with torch.no_grad():
            for batch_tensor in dataloader:
                if batch_tensor is None or batch_tensor.size(0) == 0:
                    continue
                # .to(device) is a no-op when the batch already lives there
                outputs = model(batch_tensor.to(device))
                preds = outputs.argmax(dim=1)
                image_preds.extend(preds.cpu().numpy())

        # Save prediction image
        if save_prediction:
            image_name = os.path.splitext(img_file)[0]
            output_path = os.path.join(predictions_output_dir, f"{image_name}_prediction.tif")
            save_colored_predictions_downsample(
                labels=current_label,
                predictions=image_preds,
                used_labels=label_ids,
                myotube_image=current_image,
                output_path=output_path,
                factor = downsampling_factor
            )

        all_predictions.append(image_preds)

    ## Export results

    # One statistics row per image
    prediction_data = []

    for img_file, predictions in zip(image_files, all_predictions):
        pred_array = np.array(predictions)

        # Count 0 / 1
        num_ones = np.sum(pred_array == 1)
        num_zeros = np.sum(pred_array == 0)
        total_labels = len(pred_array)

        # Fusion index = share of nuclei predicted as class 1, in percent
        if total_labels > 0:
            fusion_index = (num_ones / total_labels) * 100.0
        else:
            fusion_index = 0.0

        image_name = os.path.splitext(os.path.basename(img_file))[0]

        prediction_data.append({
            "Image Name": image_name,
            "Total Number of Nuclei": total_labels,
            "Nuclei In": num_ones,
            "Nuclei Out": num_zeros,
            "Fusion Index (%)": fusion_index
        })

    predictions_df = pd.DataFrame(prediction_data)
    output_file_path = os.path.join(parent_directory, "Fusion_index.xlsx")
    predictions_df.to_excel(output_file_path, index=False)

    print(f"\nDataFrame exported to Excel file: {output_file_path}")
    print("Classification Done")

    return all_predictions, predictions_df
        
# Run the classification pipeline only when this code executes as the
# `__main__` module.
if __name__ == "__main__":
    main()