# Tests for the VoxelProcessing.py functions

This notebook tests whether the functions implemented for voxel preprocessing are robust to pathological cases.

In [None]:
import numpy as np
import pandas as pad
from skimage.io import imread, imsave
import napari
import importlib
import os
from tqdm import tqdm
from scipy import ndimage
import warnings
import sys
from skimage.measure import regionprops

In [None]:
# Optionally switch into the `src` directory (uncomment the line below),
# then display the current working directory.
# os.chdir('src')
os.getcwd()

## 1. Remove unconnected regions

Introduce two warnings, raised when:

1. A removed region is larger than a given fraction of the largest region of the same label.
2. A removed region, of any size, lies very close to the largest region. (Only the first warning is currently implemented.)

In [None]:
def custom_showwarning(message, category, filename, lineno, file=None, line=None):
    """Compact drop-in for ``warnings.showwarning``.

    Prints ``<CategoryName>: <message>`` to stdout. The file name, line number,
    target file and source line arguments are accepted for signature
    compatibility but ignored.
    """
    category_name = category.__name__
    print("{}: {}".format(category_name, message))

In [None]:
def remove_unconnected_regions(labeled_img, pad_width=10, warning_size_threshold=0.25):
    """
    Removes regions of labels that are not connected to the main cell body.

    For every non-zero label, the connected components formed by that label's
    voxels are computed with `scipy.ndimage.label`; only the largest component
    is kept and all other components are set to background (0).

    Parameters:
    -----------
    labeled_img: (np.array, 3D)
        A 3D labeled image where the background has a label of 0 and cells are labeled with 
        consecutive integers starting from 1.

    pad_width: (int, optional, default=10)
        Not used by the current implementation; kept so that existing calls
        passing it keep working.

    warning_size_threshold: (float, optional, default=0.25)
        A warning is emitted for a label when at least one of its removed
        components is larger than this fraction of the label's largest component.

    Returns:
    --------
    filtered_labeled_img: (np.array, 3D)
        A copy of `labeled_img` with unconnected regions removed. The input
        array is not modified.
    """
    unique_labels = np.unique(labeled_img)
    filtered_labeled_img = labeled_img.copy()

    # Use the compact warning format only for the duration of this call instead
    # of permanently replacing `warnings.showwarning` for the whole session.
    with warnings.catch_warnings():
        warnings.showwarning = custom_showwarning

        for label in tqdm(unique_labels, desc='Removing unconnected regions'):
            if label == 0:
                continue
            binary_mask = (filtered_labeled_img == label).astype(np.uint8)

            # Label the connected components of this label's voxels
            labeled_mask, num_features = ndimage.label(binary_mask)
            if num_features <= 1:
                continue

            # Voxel count of each component; entry 0 is the mask background and is always 0.
            region_sizes = ndimage.sum(binary_mask, labeled_mask, range(num_features + 1))
            largest_region_label = np.argmax(region_sizes[1:]) + 1
            relative_region_sizes = region_sizes / region_sizes[largest_region_label]

            # Count only the components that are actually removed (drop the
            # background entry and the kept component). This avoids reporting
            # "-1 large regions" when the threshold is >= 1.
            removed_relative_sizes = np.delete(relative_region_sizes, [0, largest_region_label])
            num_regions_over_threshold = int(np.sum(removed_relative_sizes > warning_size_threshold))
            if num_regions_over_threshold:
                warnings.warn(f'Removing {num_regions_over_threshold} large regions with label {label} with threshold set at {warning_size_threshold}.')

            # Clear every component of this label except the largest one;
            # voxels of other labels (labeled_mask == 0) are left untouched.
            filtered_labeled_img[(labeled_mask != 0) & (labeled_mask != largest_region_label)] = 0

    return filtered_labeled_img

In [None]:
# Open a napari viewer; the following cells add their label layers to it.
viewer = napari.Viewer()

In [None]:
# Create a test image on napari
test_labels = np.zeros((100, 100, 100), dtype=np.int8)
viewer.add_labels(test_labels)

In [None]:
# Remove the unconnected fragments of every label in the test volume and
# display the cleaned result as a new layer.
clean_test_labels = remove_unconnected_regions(test_labels)
viewer.add_labels(clean_test_labels)

Try it on a real sample:

In [None]:
# Load the curated segmentation of a real sample, display it, and then display
# the version with unconnected fragments removed.
REAL_LABELS_PATH = '/nas/groups/iber/Users/Federico_Carrara/3d_tissues_preprocessing_and_segmentation/lung_segmentation/results/new_sample/central_crop/lung_new_sample_b_curated_segmentation_central_crop_relabel_seq.tif'
real_labels = imread(REAL_LABELS_PATH)
viewer.add_labels(real_labels)

clean_real_labels = remove_unconnected_regions(real_labels)
viewer.add_labels(clean_real_labels)

## 2. Check the function that removes labels touching the image edges

In [None]:
def remove_labels_touching_edges(labeled_img):
    """
    Remove all labels that touch the edge of the image.

    A label touches the edge when at least one of its voxels lies on the first
    or last index along any axis. This is equivalent to its bounding box
    reaching the image border. Works for images of any dimensionality, not
    only 3D.

    Parameters:
    -----------
    labeled_img: (np.array, typically 3D)
        The input labeled image, where the background has a label of 0 and other objects have 
        positive integer labels.

    Returns:
    --------
    filtered_labeled_img: (np.array)
        A copy of the input in which every label touching the edges is set to 0.
        The input array is not modified.
    """
    # Create a copy of the labeled image to store the filtered output
    filtered_labeled_img = labeled_img.copy()
    if labeled_img.ndim == 0 or labeled_img.size == 0:
        return filtered_labeled_img

    # Gather the label values lying on every face of the image
    # (first and last index along each axis).
    face_values = [
        np.take(labeled_img, [0, -1], axis=axis).ravel()
        for axis in range(labeled_img.ndim)
    ]
    edge_labels = np.unique(np.concatenate(face_values))
    edge_labels = edge_labels[edge_labels != 0]

    # A single vectorized pass instead of one full-image comparison per region.
    filtered_labeled_img[np.isin(labeled_img, edge_labels)] = 0

    return filtered_labeled_img

In [None]:
# Open a fresh napari viewer for this section.
viewer = napari.Viewer()

In [None]:
# Load the curated segmentation of a real sample, display it, and then display
# the version without the labels that touch the image border.
REAL_LABELS_PATH = '/nas/groups/iber/Users/Federico_Carrara/3d_tissues_preprocessing_and_segmentation/lung_segmentation/results/new_sample/central_crop/lung_new_sample_b_curated_segmentation_central_crop_relabel_seq.tif'
real_labels = imread(REAL_LABELS_PATH)
viewer.add_labels(real_labels)

filtered_labels = remove_labels_touching_edges(real_labels)
viewer.add_labels(filtered_labels)

## 3. Filter out cells without a complete neighborhood

In [None]:
def get_cell_neighbors(labeled_img: np.ndarray, cell_id: int):
    """
    Get all the neighbors of a given cell. Two cells are considered neighbors if 
    a subset of their surfaces are directly touching.

    The cell is dilated by 2 voxels with the default (cross-shaped) structuring
    element of `scipy.ndimage.binary_dilation`. The labels found in the
    resulting shell around the cell are its neighbors. The neighborhood is
    flagged as complete when the number of background (0) voxels in the shell
    is not larger than the number of shell voxels belonging to the
    most-contacted neighbor.
    
    Parameters:
    -----------
    labeled_img: (np.array, 3D)
        Labeled image where 0 is background and every cell has its own integer label.

    cell_id: (int)
        The id of the cell for which we want to find the neighbors

    Returns
    -------
    neighbors_lst: (list of int)
        The ids of the neighbors in ascending order (background and `cell_id` excluded).

    complete_neighborhood: (bool)
        True if the background contact does not exceed the contact with the
        most-contacted neighbor. False when the cell has no neighbors or when
        `cell_id` does not occur in `labeled_img`.
    """
    dilation_iterations = 2

    # Get the voxels of the cell
    binary_img = labeled_img == cell_id
    cell_coords = np.nonzero(binary_img)
    if cell_coords[0].size == 0:
        # The cell is absent: there is no shell to inspect.
        return [], False

    # Work on the cell's bounding box padded by the dilation reach. The shell
    # cannot extend any further, so there is no need to dilate the whole image
    # for every cell.
    window = tuple(
        slice(max(int(axis_coords.min()) - dilation_iterations, 0),
              int(axis_coords.max()) + dilation_iterations + 1)
        for axis_coords in cell_coords
    )
    local_labels = labeled_img[window]
    local_cell = binary_img[window]

    # Voxels reached by the dilation that are not part of the cell itself
    expanded_cell_voxels = ndimage.binary_dilation(local_cell, iterations=dilation_iterations)
    cell_surface_voxels = expanded_cell_voxels & ~local_cell

    shell_labels, shell_counts = np.unique(local_labels[cell_surface_voxels], return_counts=True)

    # Separate background contact from contact with actual neighboring cells
    is_neighbor = (shell_labels != cell_id) & (shell_labels != 0)
    neighbors_lst = shell_labels[is_neighbor]
    neighbor_counts = shell_counts[is_neighbor]
    # Look up the background count explicitly instead of assuming it is the
    # first unique label: a fully enclosed cell has no background in its shell.
    background_count = shell_counts[shell_labels == 0].sum()

    complete_neighborhood = bool(
        neighbor_counts.size > 0 and background_count <= neighbor_counts.max()
    )

    return neighbors_lst.tolist(), complete_neighborhood

In [None]:
# Open a fresh napari viewer for this section.
viewer = napari.Viewer()

In [None]:
# Load the curated segmentation of the real sample and display it.
real_labels = imread('/nas/groups/iber/Users/Federico_Carrara/3d_tissues_preprocessing_and_segmentation/lung_segmentation/results/new_sample/central_crop/lung_new_sample_b_curated_segmentation_central_crop_relabel_seq.tif')
viewer.add_labels(real_labels)

In [None]:
# Collect every non-background label whose neighborhood is reported as
# complete by get_cell_neighbors.
complete_cells = []
for label in tqdm(np.unique(real_labels)):
    if label != 0:
        _, is_complete = get_cell_neighbors(real_labels, label)
        if is_complete:
            complete_cells.append(label)

In [None]:
# Zero out every label that is not in `complete_cells` with a single vectorized
# pass, instead of one full-image comparison per label. The background label 0
# ends up as 0 either way.
filtered_labels = real_labels.copy()
filtered_labels[~np.isin(filtered_labels, list(complete_cells))] = 0

viewer.add_labels(filtered_labels)

## 4. Visualize the images to understand what filtering is needed

In [None]:
PATH_TO_IMGS = '/nas/groups/iber/Users/Federico_Carrara/create_meshes/data/curated_labels/'
# Load only TIFF files: any other file in the folder would be passed to imread.
# Sorting makes the layer order reproducible, since os.listdir order is arbitrary.
file_names = sorted(
    name for name in os.listdir(PATH_TO_IMGS)
    if name.lower().endswith(('.tif', '.tiff'))
)

viewer = napari.Viewer()

for file in file_names:
    curr_img = imread(os.path.join(PATH_TO_IMGS, file))
    # Strip only the extension, so '.tiff' names and names containing '.tif'
    # elsewhere are not mangled.
    viewer.add_labels(curr_img, name=os.path.splitext(file)[0])