In [1]:
from tqdm.auto import tqdm
import glob

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

In [2]:
import requests
import zipfile
import shutil
import os

import sys
# Keep the working directory free of __pycache__ (modules.py is downloaded into it and
# star-imported later) and limit how much traceback is printed when a cell raises.
# NOTE(review): tracebacklimit = 0 also hides the stack for genuine bugs; remove it while debugging.
sys.dont_write_bytecode, sys.tracebacklimit = True, 0

def _download_with_progress(url, filename, desc, block_size = 1024):
    """
    Stream ``url`` into ``filename`` while displaying a tqdm progress bar.

    Raises requests.HTTPError (via raise_for_status) for an error status. The progress bar
    and the output file are closed even if the transfer fails part-way.
    """
    with requests.get(url, stream=True) as response:
        response.raise_for_status()
        # 0 when the server sends no content-length header.
        total_size_in_bytes = int(response.headers.get('content-length', 0))
        with tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True, desc=desc) as progress_bar:
            with open(filename, 'wb') as f:
                for data in response.iter_content(block_size):
                    progress_bar.update(len(data))
                    f.write(data)


def download_and_setup_model(option):
    """
    Download the helper ``modules.py`` and a pretrained segmentation model into the CWD.

    Any previous copies of ``modules.py``, the model zip and the extracted model are
    removed first, so each call starts from a clean state.

    Args:
    option (str): Which model to fetch: 'stardist' or 'cellpose'.

    Returns:
    str: Path (relative to the CWD) of the extracted model.

    Raises:
    ValueError: If ``option`` is not a known model name.
    requests.HTTPError: If either download returns an error status.
    """
    # NOTE(review): modules.py is fetched from the moving `main` branch and later executed via
    # `from modules import *`; pin the URL to a commit hash for reproducibility and safety.
    python_file_url = 'https://raw.githubusercontent.com/ajinkya-kulkarni/PyBlendPatches/main/modules.py'
    python_filename = 'modules.py'

    # option -> (zip URL, local zip filename, path the zip extracts to)
    model_sources = {
        'stardist': (
            'https://github.com/ajinkya-kulkarni/TrainedModels/raw/main/StarDist/stardist_model.zip',
            'StarDistTrainedModel.zip',
            'stardist_model',
        ),
        'cellpose': (
            'https://github.com/ajinkya-kulkarni/TrainedModels/raw/main/Cellpose/cellpose_model.zip',
            'CellposeTrainedModel.zip',
            'cellpose_model.941139',
        ),
    }
    try:
        zip_file_url, zip_filename, model_folder_path = model_sources[option]
    except (KeyError, TypeError):
        # TypeError covers unhashable options, which are simply not valid names either.
        raise ValueError(f"Invalid option: {option}") from None

    # Download and save the helper Python file, replacing any stale copy.
    if os.path.exists(python_filename):
        os.remove(python_filename)
    _download_with_progress(python_file_url, python_filename, f'Downloading {python_filename}')

    # Clear leftovers from earlier runs so extraction starts clean.
    if os.path.exists(zip_filename):
        os.remove(zip_filename)
    if os.path.isdir(model_folder_path):
        shutil.rmtree(model_folder_path)
    elif os.path.isfile(model_folder_path):
        os.remove(model_folder_path)

    _download_with_progress(zip_file_url, zip_filename, f'Downloading {option} model')

    try:
        with zipfile.ZipFile(zip_filename, 'r') as zip_ref:
            zip_ref.extractall('.')
    finally:
        # Do not leave the archive behind even if extraction fails.
        os.remove(zip_filename)

    return model_folder_path

In [3]:
# Fetch modules.py and the trained Cellpose weights; keep the path of the extracted model.
model_path = download_and_setup_model(option='cellpose')

Downloading modules.py:   0%|          | 0.00/4.32k [00:00<?, ?iB/s]

Downloading cellpose model:   0%|          | 0.00/24.7M [00:00<?, ?iB/s]

In [4]:
from modules import *

In [5]:
# Colormap for displaying label images; random_label_cmap presumably comes from the
# star-imported modules.py — confirm.
lbl_cmap = random_label_cmap()

In [6]:
from cellpose import models

# Loaded Cellpose models keyed by (gpu_usage, model path). Loading weights is expensive, so
# each configuration is loaded once instead of once per patch.
_CELLPOSE_MODEL_CACHE = {}


def predict_mask_from_image_cellpose(normalized_img_patch, gpu_usage = False, pretrained_model = None):
    """
    Predict a Cellpose label mask for a single image patch.

    Args:
    normalized_img_patch (numpy array): The normalized image patch to segment.
    gpu_usage (bool): Forwarded as ``gpu=`` to ``models.CellposeModel``.
    pretrained_model (str or None): Path of the trained Cellpose model. When None, the
        notebook-level ``model_path`` is used (looked up at call time).

    Returns:
    numpy array (uint16): The predicted label mask, or an all-zero uint16 array with the
        patch's shape when Cellpose finds nothing.
    """
    if pretrained_model is None:
        pretrained_model = model_path

    cache_key = (gpu_usage, str(pretrained_model))
    model = _CELLPOSE_MODEL_CACHE.get(cache_key)
    if model is None:
        model = models.CellposeModel(gpu = gpu_usage, pretrained_model = pretrained_model)
        _CELLPOSE_MODEL_CACHE[cache_key] = model

    # NOTE(review): [[0, 0]] presumably selects single-channel grayscale input in Cellpose's
    # channel convention — confirm against the installed Cellpose version.
    channels = [[0, 0]]

    mask, _flows, _styles = model.eval(normalized_img_patch, diameter=None, channels=channels)

    if np.any(mask):
        return mask.astype('uint16')
    # Empty prediction: return zeros matching the input patch.
    return np.zeros(normalized_img_patch.shape, dtype=np.uint16)

In [7]:
import warnings

# Root of the test set. Set the PYORGANOID_TEST_SET environment variable to point at a
# different location instead of editing this machine-specific default.
folder_path = os.environ.get(
    'PYORGANOID_TEST_SET',
    '/home/ajinkya/Desktop/PyOrganoidAnalysis/DataSetPreparation/Test_Set',
)
# Glob pattern (not a directory) matching the input images.
data_dir = os.path.join(folder_path, 'images', '*.tif')

# glob.glob already returns a list; sort it for a deterministic processing order.
predict_images_path = sorted(glob.glob(data_dir))

if not predict_images_path:
    warnings.warn(f'No .tif images matched {data_dir}')

In [8]:
# Patch edge length handed to patchify (presumably pixels) and a 50% overlap between neighbouring patches.
window_size = 800
overlap = window_size // 2

In [9]:
images_array = []
labels_array = []

for image_path in tqdm(predict_images_path, desc='Reading images and predicting', leave = True):

    normalized_img = read_image_as_grayscale_then_MinMax_normalize(image_path)

    ###########################################################################################
    # Split the image into overlapping patches of window_size with the configured overlap.

    patches, window_coords = patchify(normalized_img, window_size, overlap)

    ###########################################################################################
    # Segment every patch with Cellpose, then smooth the resulting labels.

    predicted_labels = []

    for patch in tqdm(patches, desc = 'Predicting on patches', leave = False):
        label = predict_mask_from_image_cellpose(patch, gpu_usage = True)
        predicted_labels.append(smooth_segmented_labels(label))

    ###########################################################################################
    # Clean each patch's labels with respect to its window coordinates (remove_border_labels).

    border_cleaned_predicted_labels = [
        remove_border_labels(patch_label, patch_coords, normalized_img)
        for patch_label, patch_coords in zip(predicted_labels, window_coords)
    ]

    ###########################################################################################

    region_info_list = compile_label_info(np.array(border_cleaned_predicted_labels), window_coords)

    ###########################################################################################
    # Non-Maximum Suppression over the regions' global bounding boxes, with an overlap
    # threshold of 0.5, to drop duplicates of objects seen in several overlapping patches.

    if len(region_info_list) > 0:
        # One bounding box per region (layout as produced by compile_label_info's 'global_bbox').
        boxes = np.array([region['global_bbox'] for region in region_info_list])
        nms_indices = non_maximum_suppression(boxes, overlapThresh=0.5)
        # Keep only the regions whose boxes survived NMS.
        nms_region_info_list = [region_info_list[i] for i in nms_indices]
    else:
        # No regions in any patch: skip NMS rather than pass it an empty, shape-(0,) array.
        nms_region_info_list = []

    ###########################################################################################

    canvas = place_labels_on_canvas(normalized_img, nms_region_info_list)

    predictions_dir = os.path.join(folder_path, 'AlgorithmPredictions')
    # The output folder may not exist yet on a fresh checkout.
    os.makedirs(predictions_dir, exist_ok = True)
    mask_filename = os.path.join(predictions_dir,
                                 os.path.splitext(os.path.basename(image_path))[0] + '_CellposeMask.tif')

    if os.path.exists(mask_filename):
        os.remove(mask_filename)
    # NOTE(review): `Image` is not imported in any visible cell; it presumably leaks in via
    # `from modules import *` (PIL.Image). Import it explicitly in the imports cell.
    canvas_image = Image.fromarray(canvas)
    canvas_image.save(mask_filename, format='TIFF')

    ###########################################################################################

    images_array.append(np.asarray(normalized_img))
    labels_array.append(canvas)

Reading images and predicting:   0%|          | 0/20 [00:00<?, ?it/s]

Predicting on patches:   0%|          | 0/6 [00:00<?, ?it/s]

Predicting on patches:   0%|          | 0/6 [00:00<?, ?it/s]

Predicting on patches:   0%|          | 0/6 [00:00<?, ?it/s]

Predicting on patches:   0%|          | 0/6 [00:00<?, ?it/s]

Predicting on patches:   0%|          | 0/6 [00:00<?, ?it/s]

Predicting on patches:   0%|          | 0/6 [00:00<?, ?it/s]

Predicting on patches:   0%|          | 0/6 [00:00<?, ?it/s]

Predicting on patches:   0%|          | 0/6 [00:00<?, ?it/s]

Predicting on patches:   0%|          | 0/6 [00:00<?, ?it/s]

Predicting on patches:   0%|          | 0/6 [00:00<?, ?it/s]

Predicting on patches:   0%|          | 0/6 [00:00<?, ?it/s]

Predicting on patches:   0%|          | 0/6 [00:00<?, ?it/s]

Predicting on patches:   0%|          | 0/6 [00:00<?, ?it/s]

Predicting on patches:   0%|          | 0/6 [00:00<?, ?it/s]

Predicting on patches:   0%|          | 0/6 [00:00<?, ?it/s]

Predicting on patches:   0%|          | 0/6 [00:00<?, ?it/s]

Predicting on patches:   0%|          | 0/6 [00:00<?, ?it/s]

Predicting on patches:   0%|          | 0/6 [00:00<?, ?it/s]

Predicting on patches:   0%|          | 0/6 [00:00<?, ?it/s]

Predicting on patches:   0%|          | 0/6 [00:00<?, ?it/s]