In [None]:
from skimage.io import imread
from skimage.measure import regionprops
from skimage.transform import rescale
from scipy import ndimage
import numpy as np
import napari
import os
import sys
import matplotlib.pyplot as plt
from tqdm import tqdm
import trimesh as tm
from scipy.spatial.transform import Rotation
from typing import Union, Optional, Iterable, List, Tuple, Dict
from morphosamplers.sampler import (
    generate_2d_grid,
    generate_3d_grid,
    place_sampling_grids,
    sample_volume_at_coordinates,
)
from napari_process_points_and_surfaces import label_to_surface
from scipy.linalg import polar
from scipy.spatial.distance import euclidean
from collections import defaultdict

## Visualize vectors and directions

In [None]:
# Load the curated segmentation and move the first axis to the end (kij -> ijk).
# The whole label image is transposed, not only the single-cell mask. Coordinates derived
# from `cell_img` (centroid, slicing points) are later used to resample `labeled_img`,
# so both must live in the same axis order.
labeled_img = imread('../curated_labels/intestine_sample2_b_curated_segmentation_relabel_seq.tif')
labeled_img = np.einsum('kij->ijk', labeled_img)
cell_img = (labeled_img == 301).astype(np.uint16)

In [None]:
# Geometry of the single-cell mask: centroid (voxel coordinates) and major-axis length,
# truncated to whole voxels so it can serve as a slicing half-range.
props = regionprops(cell_img)[0]
cell_centroid, cell_length = props.centroid, int(props.axis_major_length)

In [None]:
# Mesh of cell 301 produced by the statistics pipeline.
cell_mesh = tm.load_mesh('../outputs/outputs_v2/output_intestine_sample2_b_curated_segmentation_relabel_seq_s_10_e_6_d_8/cell_meshes/cell_301.stl')

# Inertia eigen-decomposition; eigenvectors are indexed by row, as elsewhere in this notebook.
eigen_values, eigen_vectors = tm.inertia.principal_axis(cell_mesh.moment_inertia)
smallest_eigen_value_idx, greatest_eigen_value_idx = (
    np.argmin(np.abs(eigen_values)),
    np.argmax(np.abs(eigen_values)),
)

# Smallest |eigenvalue| -> principal (long) axis. Divide by the voxel spacing to move the
# direction from mesh units into image index units, then rescale to unit length.
# NOTE(review): z spacing 0.1625 here, while the analysis cells below use 0.25 — confirm.
principal_axis = eigen_vectors[smallest_eigen_value_idx] / np.array([0.325, 0.325, 0.1625])
principal_axis = principal_axis / np.linalg.norm(principal_axis)

print(principal_axis)
print(np.linalg.norm(principal_axis))

In [None]:
# Alternative (disabled): build and smooth the mesh directly from `cell_img` with
# label_to_surface instead of loading the pipeline's .stl, then take its principal axis.
# surface = label_to_surface(cell_img)
# points, faces = surface[0], surface[1]
# points = (points)
# cell_mesh = tm.Trimesh(points, faces)
# cell_mesh = tm.smoothing.filter_mut_dif_laplacian(cell_mesh, iterations=10, lamb=0.5)

# # Compute the principal axis
# eigen_values, eigen_vectors = tm.inertia.principal_axis(cell_mesh.moment_inertia)

# # Get the index of the smallest eigen value
# smallest_eigen_value_idx = np.argmin(np.abs(eigen_values))
# greatest_eigen_value_idx = np.argmax(np.abs(eigen_values))

# # Get the corresponding eigen vector 
# principal_axis = eigen_vectors[smallest_eigen_value_idx]
# principal_axis = principal_axis

# print(principal_axis)
# print(np.linalg.norm(principal_axis))

In [None]:
# 20 points evenly spaced along the principal axis, spanning ±cell_length around the centroid.
slicing_dir = np.linspace(-cell_length, cell_length, 20)[:, np.newaxis] * principal_axis + cell_centroid
print(cell_centroid, slicing_dir)

In [None]:
# Visual check: single-cell mask, its mesh and the sampled points along the principal axis.
viewer = napari.Viewer()
viewer.add_labels(cell_img)
viewer.add_surface((cell_mesh.vertices, cell_mesh.faces))
viewer.add_points(slicing_dir, size=1)

## Extract 3D volume

In [None]:
# Sample a 3D box of labels aligned with the principal axis. The first grid dimension has
# 2*cell_length + 1 samples (presumably mapped onto principal_axis, the first rotation
# column); the other two span a 201 x 201 cross-section.
grid_shape = [(2*cell_length + 1), 201, 201]
grid_center_point = cell_centroid

# Build an orthonormal, right-handed basis (principal_axis, normal_vector, third_vector).
# [1, 0, 0] is only a reference direction. It must be made orthogonal to the principal
# axis (Gram-Schmidt) and normalised, otherwise the stacked columns are not a rotation.
# NOTE: degenerate if principal_axis is (anti)parallel to [1, 0, 0].
reference_vector = np.array([1, 0, 0])
normal_vector = reference_vector - np.dot(reference_vector, principal_axis) * principal_axis
normal_vector = normal_vector / np.linalg.norm(normal_vector)
third_vector = np.cross(principal_axis, normal_vector)

rot_matrix = np.column_stack(
    [principal_axis, normal_vector, third_vector]
)

rotations = [Rotation.from_matrix(rot_matrix)]

# Generate the grid, then rotate it and centre it on the cell centroid.
grid = generate_3d_grid(grid_shape)
sampling_coordinates = place_sampling_grids(
    grid, grid_center_point, rotations
)

# Nearest-neighbour sampling keeps label values intact.
sampled_vol = sample_volume_at_coordinates(
    labeled_img,
    sampling_coordinates,
    interpolation_order=0,
)

In [None]:
# Compare the sampling points and the sampled 3D box (sampled_vol) with the full label image.
viewer = napari.Viewer()
viewer.add_points(sampling_coordinates.reshape((-1, 3)), size=0.5)
viewer.add_labels(labeled_img)
viewer.add_labels(sampled_vol)

## Extract 2D slices

In [None]:
def _get_slices_along_direction(
    labeled_img: np.ndarray,
    slicing_dir: Iterable[float],
    centroid: Iterable[float],
    height: float,    
    slice_size: Optional[int] = 200,
    num_slices: Optional[int] = 10,
    debugging_mode: Optional[bool] = False,
) -> Tuple[List[np.ndarray[int]], Tuple[List[np.ndarray[float]], np.ndarray[float]]]:
    """
    Sample 2D slices of a 3D labeled image on planes orthogonal to `slicing_dir`.

    Parameters:
    -----------
        labeled_img: (np.ndarray)
            3D labeled image, sampled with nearest-neighbour interpolation so labels are kept.
        slicing_dir: (Iterable[float])
            Direction (in labeled_img index coordinates) orthogonal to the slicing planes.
            Slice centers use it as given; it is normalised internally only to build
            the in-plane basis.
        centroid: (Iterable[float])
            Point around which the slices are centered.
        height: (float)
            Slice centers are `centroid + t * slicing_dir` for t evenly spaced in [-height, height].
        slice_size: (Optional[int], default=200)
            Side of each square sampling grid; the grid has slice_size + 1 points per side.
        num_slices: (Optional[int], default=10)
            Number of slices.
        debugging_mode: (Optional[bool], default=False)
            If True, add the slice centers and in-plane axes as point layers to a napari
            `viewer` that must already exist in the notebook's global namespace.

    Returns:
    --------
        labeled_slices: (List[np.ndarray[int]])
            One sampled plane per slice center.
        grid_specs: (Tuple[np.ndarray[float], np.ndarray[float]])
            (slice centers, slicing direction), identifying the grids used for sampling.
    """
    slicing_dir = np.asarray(slicing_dir, dtype=float)
    unit_dir = slicing_dir / np.linalg.norm(slicing_dir)
    centroid_arr = np.asarray(centroid, dtype=float)

    # Centers of the sampling grids, evenly spaced along the slicing direction.
    offsets = np.linspace(-height, height, num_slices)
    slicing_centers = offsets[:, np.newaxis] * slicing_dir + centroid_arr

    # Orthonormal, right-handed basis (u, v, unit_dir) via Gram-Schmidt. The seed is the
    # coordinate axis least aligned with unit_dir, which keeps the basis well-conditioned.
    # It is also deterministic: a random seed made the in-plane orientation, and hence the
    # sampled slices, change between runs.
    seed_vector = np.eye(3)[np.argmin(np.abs(unit_dir))]
    normal_vector = seed_vector - np.dot(seed_vector, unit_dir) * unit_dir
    normal_unit_vector = normal_vector / np.linalg.norm(normal_vector)
    third_vector = np.cross(unit_dir, normal_unit_vector)
    rot_matrix = np.column_stack(
        (normal_unit_vector, third_vector, unit_dir)
    )
    rotations = [Rotation.from_matrix(rot_matrix)]

    if debugging_mode:
        # Relies on a napari `viewer` defined in the notebook's global namespace.
        viewer.add_points(slicing_centers)
        viewer.add_points(offsets[:, np.newaxis] * third_vector + centroid_arr)
        viewer.add_points(offsets[:, np.newaxis] * normal_unit_vector + centroid_arr)

    # Save specifiers of the different grids
    grid_specs = (slicing_centers, slicing_dir)

    # The grid is identical for every slice, so build it once.
    grid = generate_2d_grid([slice_size + 1] * 2)

    labeled_slices = []
    for center in slicing_centers:
        # Rotate the grid into the slicing plane and translate it to this center.
        sampling_coordinates = place_sampling_grids(grid, center, rotations)
        labeled_slices.append(
            sample_volume_at_coordinates(
                labeled_img,
                sampling_coordinates,
                interpolation_order=0,
            )
        )

    return labeled_slices, grid_specs

In [None]:
#------------------------------------------------------------------------------------------------------------
def _compute_2D_area_along_direction(
		labeled_slice: np.ndarray[int], 
		cell_label: np.ndarray[int],
		pixel_size: Iterable[float]
) -> float:
	
	binary_slice = (labeled_slice == cell_label).astype(np.uint16)
	if np.any(binary_slice):
		pixel_count = np.sum(binary_slice)
		pixel_area = pixel_size[0] * pixel_size[1]
		area = pixel_count * pixel_area
		return area
	else:
		return 0.0
#------------------------------------------------------------------------------------------------------------



#------------------------------------------------------------------------------------------------------------
def _compute_2D_neighbors_along_direction(
	labeled_slice: np.ndarray[int], 
	cell_label: np.ndarray[int],
	background_threshold: float = 0.1
) -> List[int]:

	#Get the pixels of the cell
	binary_slice = labeled_slice == cell_label

	# Check if cell is present in the slice
	if not np.any(binary_slice):
		return [-1]

	#Expand the volume of the cell by 2 voxels in each direction
	expanded_cell_voxels = ndimage.binary_dilation(binary_slice, iterations=2)
		
	#Find the voxels that are directly in contact with the surface of the cell
	cell_surface_voxels = expanded_cell_voxels ^ binary_slice

	#Get the labels of the neighbors
	neighbors, counts = np.unique(labeled_slice[cell_surface_voxels], return_counts=True)

	#Check if the label is touching the background above a certain threshold
	# print(f'Cell {cell_label}: {neighbors}, {counts}')
	if (0 in neighbors) and (counts[0] > np.sum(counts) * background_threshold):
		return [-1]
	else:
		#Remove the label of the cell itself, and the label of the background from the neighbors list
		neighbors = neighbors[(neighbors != cell_label) & (neighbors != 0)]
		return list(neighbors)

In [None]:
def compute_2D_statistics_along_axis(
    labeled_img: np.ndarray[int],
    meshes: Iterable[tm.base.Trimesh],
    exclude_labels: Iterable[int],
    voxel_size: Iterable[float],
    slice_ext: int = 200,
    number_slices: int = 10, 
    remove_empty: Optional[bool] = False
) -> Tuple[Dict[int, List[List[int]]], Dict[int, List[int]], Dict[int, List[float]], Dict[int, tuple]]:
    """
    Collect 2D neighbor and area statistics on slices orthogonal to each cell's principal axis.

    For every non-background label, the principal axis is taken from the cell's mesh.
    `number_slices` planes orthogonal to it are sampled around the cell centroid, and on
    each plane the cell area and the list of neighboring labels are computed.

    Parameters:
    -----------
        labeled_img: (np.ndarray[int])
            3D labeled image. The smallest label (assumed to be background 0) is skipped.
        meshes: (Iterable[tm.base.Trimesh])
            Indexable sequence of meshes. meshes[i] must belong to the i-th non-background
            label in ascending order; this alignment is assumed, not checked.
        exclude_labels: (Iterable[int])
            Labels to skip. They get empty entries in the output dicts.
        voxel_size: (Iterable[float])
            Voxel size of labeled_img in mesh units. It is used to map the mesh principal
            axis into image index units, and its first two entries give the pixel area of
            the slices (an approximation for oblique slices).
        slice_ext: (int, default=200)
            Side of the square sampling grid; clamped to the smallest image side.
        number_slices: (int, default=10)
            Number of slices per cell.
        remove_empty: (Optional[bool], default=False)
            If True, drop slices whose neighbor list is [-1] (cell absent or touching
            the background).

    Returns:
    --------
        neighbors_dict: (Dict[int, List[List[int]]])
            Per label, the neighbor labels on each slice ([-1] marks an unusable slice).
        num_neighbors_dict: (Dict[int, List[int]])
            Per label, len() of each slice's neighbor list. NOTE: an unusable slice ([-1])
            counts as 1 unless remove_empty is True.
        areas_dict: (Dict[int, List[float]])
            Per label, the cell area on each slice.
        slices_dict: (Dict[int, tuple])
            Per label, (slice centers, slicing direction); () for excluded labels.
    """
    # A grid larger than the image cannot be filled: clamp to the smallest image side.
    if np.any(slice_ext > np.asarray(labeled_img.shape)):
        slice_ext = np.min(labeled_img.shape)

    scale = np.asarray(voxel_size, dtype=float)
    # Materialise once: works for any iterable (a generator would be consumed by the
    # first `in` test) and gives O(1) membership checks.
    excluded = set(exclude_labels)

    # Iterate over the cells
    label_ids = np.unique(labeled_img)

    neighbors_dict = {}
    num_neighbors_dict = {}
    areas_dict = {}
    slices_dict = {}
    for i, label_id in tqdm(enumerate(label_ids[1:]), desc='Computing cell 2D statistics along apical-basal axis', total=(len(label_ids)-1)):
        if label_id in excluded:
            neighbors_dict[label_id], areas_dict[label_id] = [[]], []
            num_neighbors_dict[label_id] = []
            slices_dict[label_id] = ()
            continue

        # Principal axis: inertia eigenvector with the smallest |eigenvalue|.
        cell_mesh = meshes[i]
        eigen_values, eigen_vectors = tm.inertia.principal_axis(cell_mesh.moment_inertia)
        principal_axis = np.asarray(eigen_vectors[np.argmin(np.abs(eigen_values))])
        # Mesh coordinates are (assumed) image indices scaled by voxel_size. A direction
        # therefore maps into image index space by DIVIDING by voxel_size, consistent with
        # the other cells of this notebook. Multiplying tilted the axis the wrong way for
        # anisotropic voxels.
        principal_axis = principal_axis / scale
        principal_axis = principal_axis / np.linalg.norm(principal_axis)

        binary_img = (labeled_img == label_id).astype(np.uint8)
        props = regionprops(binary_img)[0]
        cell_centroid = props.centroid
        cell_length = int(props.axis_major_length)

        # Get slices along principal axis direction
        labeled_slices, slices_specs = _get_slices_along_direction(
            labeled_img=labeled_img,
            slicing_dir=principal_axis,
            centroid=cell_centroid,
            height=cell_length,
            slice_size=slice_ext,
            num_slices=number_slices
        )

        # Per-slice statistics
        cell_areas = []
        cell_neighbors = []
        for labeled_slice in labeled_slices:
            cell_areas.append(_compute_2D_area_along_direction(
                labeled_slice=labeled_slice,
                cell_label=label_id,
                # Use the caller's voxel size instead of a hard-coded spacing.
                pixel_size=scale
            ))
            cell_neighbors.append(_compute_2D_neighbors_along_direction(
                labeled_slice=labeled_slice,
                cell_label=label_id
            ))

        if remove_empty:
            # Drop slices where the cell is absent or touches the background ([-1]).
            to_remove = [neighs == [-1] for neighs in cell_neighbors]
            cell_neighbors = [neighs for neighs, flag in zip(cell_neighbors, to_remove) if not flag]
            cell_areas = [area for area, flag in zip(cell_areas, to_remove) if not flag]
            new_slices_specs = [item for item, flag in zip(slices_specs[0], to_remove) if not flag]
            slices_specs = (new_slices_specs, slices_specs[1])

        neighbors_dict[label_id] = cell_neighbors
        num_neighbors_dict[label_id] = [len(neighs) for neighs in cell_neighbors]
        areas_dict[label_id] = cell_areas
        slices_dict[label_id] = slices_specs

    return neighbors_dict, num_neighbors_dict, areas_dict, slices_dict

### Example

In [None]:
# Sample a 2D plane orthogonal to `principal_axis` at each point of `slicing_dir`.
# The plane basis and the grid do not depend on the point, so build them once.
grid_shape = [201, 201]

# Gram-Schmidt: make [1, 0, 0] orthogonal to the principal axis and normalise it.
# third_vector = normal x principal gives columns (third, normal, principal) with
# determinant +1. cross(principal, normal) in first position is a reflection, not a rotation.
# NOTE: degenerate if principal_axis is (anti)parallel to [1, 0, 0].
reference_vector = np.array([1, 0, 0])
normal_vector = reference_vector - np.dot(reference_vector, principal_axis) * principal_axis
normal_vector = normal_vector / np.linalg.norm(normal_vector)
third_vector = np.cross(normal_vector, principal_axis)

rot_matrix = np.column_stack(
    [third_vector, normal_vector, principal_axis]
)
rotations = [Rotation.from_matrix(rot_matrix)]
grid = generate_2d_grid(grid_shape)

labeled_slices = []
for point in slicing_dir:
    grid_center_point = point
    sampling_coordinates = place_sampling_grids(
        grid, grid_center_point, rotations
    )

    # Nearest-neighbour sampling: linear interpolation would blend label ids into
    # values that are not labels.
    sampled_plane = sample_volume_at_coordinates(
        labeled_img,
        sampling_coordinates,
        interpolation_order=0,
    )

    labeled_slices.append(sampled_plane)

In [None]:
# Show every sampled plane as its own labels layer.
viewer = napari.Viewer()
for labeled_slice in labeled_slices:
    viewer.add_labels(labeled_slice)

## Test on cubes

In [None]:
import src.tests.CubeLatticeTest as cube

In [None]:
# Synthetic labeled volume from the project's cube-lattice generator.
# NOTE(review): the positional arguments are unnamed here; see src.tests.CubeLatticeTest
# for their meaning.
test_img = cube.generate_cube_lattice_image(
    100, 100, 100, 20, 3, 3, 1, 0
)

In [None]:
# Fresh viewer for the synthetic cube-lattice test.
viewer = napari.Viewer()

In [None]:
# Show the synthetic labels.
viewer.add_labels(test_img)

In [None]:
# Build one mesh per label of the synthetic image with the project helper.
# Its values are used below as the mesh list; presumably they are ordered by label
# (TODO confirm, compute_2D_statistics_along_axis relies on that order).
from src.statistics_collection.GenMeshes import convert_labels_to_meshes

meshes_dict = convert_labels_to_meshes(
    img=test_img,
    voxel_resolution=[0.1, 0.1, 0.1],
    smoothing_iterations=0,
    output_directory='./src/tests/output',
    pad_width=5
)

In [None]:
# Run the 2D statistics on the synthetic image. voxel_size is a required argument;
# use the same resolution the meshes were built with.
results = compute_2D_statistics_along_axis(
    labeled_img=test_img,
    meshes=list(meshes_dict.values()),
    exclude_labels=[],
    voxel_size=[0.1, 0.1, 0.1],
    remove_empty=False
)

## Test on a single cell

In [None]:
# Curated labels with the first axis moved to the end (kij -> ijk), plus mesh and mask of cell 301.
intestine_img = np.einsum('kij->ijk', imread('../curated_labels/intestine_sample2_b_curated_segmentation_relabel_seq.tif'))
cell_mesh = tm.load_mesh('../outputs/outputs_v2/output_intestine_sample2_b_curated_segmentation_relabel_seq_s_10_e_6_d_8/cell_meshes/cell_301.stl')
cell_img = (intestine_img == 301).astype(np.uint8)

# Centroid and major-axis length (whole voxels) of the mask.
props = regionprops(cell_img)[0]
cell_centroid, cell_length = props.centroid, int(props.axis_major_length)

# Principal axis = inertia eigenvector with the smallest |eigenvalue|.
eigen_values, eigen_vectors = tm.inertia.principal_axis(cell_mesh.moment_inertia)
smallest_eigen_value_idx = np.abs(eigen_values).argmin()
greatest_eigen_value_idx = np.abs(eigen_values).argmax()
principal_axis = eigen_vectors[smallest_eigen_value_idx]

# Mesh units -> image index units, then unit length.
rescaled_principal_axis = principal_axis / np.asarray([0.325, 0.325, 0.25])
rescaled_normalized_principal_axis = rescaled_principal_axis / np.linalg.norm(rescaled_principal_axis)

In [None]:
# Display the rescaled, unit-length principal axis of cell 301.
rescaled_normalized_principal_axis

In [None]:
# Inspect the single-cell mask together with its mesh.
viewer = napari.Viewer()
viewer.add_labels(cell_img)
viewer.add_surface((cell_mesh.vertices, cell_mesh.faces))

In [None]:
# Visualise the slicing geometry of cell 301 in the viewer opened above. With
# debugging_mode the function adds the slice centres and in-plane axes to `viewer`.
# Without it the call had no visible effect, because both return values are discarded.
_, _ = _get_slices_along_direction(
    labeled_img=cell_img,
    slicing_dir=rescaled_normalized_principal_axis,
    centroid=cell_centroid,
    height=cell_length,
    num_slices=11,
    debugging_mode=True,
)

## Test on Intestine sample

In [None]:
# Processed label image, one mesh per cell ordered by the numeric label in the file name
# (cell_<label>.stl), and the labels of cut cells to exclude from the statistics.
intestine_img = imread('../outputs/outputs_v2/output_intestine_sample2_b_curated_segmentation_relabel_seq_s_10_e_6_d_8/processed_labels.tif')
# intestine_img = np.einsum('kij->ijk', intestine_img)
meshes_dir = '../outputs/outputs_v2/output_intestine_sample2_b_curated_segmentation_relabel_seq_s_10_e_6_d_8/cell_meshes/'
file_names = os.listdir(meshes_dir)
sorted_file_names = sorted(file_names, key=lambda name: int(name.split("_")[1].split(".")[0]))
meshes_files = [os.path.join(meshes_dir, name) for name in sorted_file_names]
meshes = [tm.load_mesh(path) for path in tqdm(meshes_files)]

with open('../outputs/outputs_v2/output_intestine_sample2_b_curated_segmentation_relabel_seq_s_10_e_6_d_8/cut_cells_labels.txt', 'r') as file:
    to_exclude = [int(float(line.strip())) for line in file]

In [None]:
# Full-sample run: skip cut cells and keep only slices where the cell is present and not
# touching the background (remove_empty=True).
results = compute_2D_statistics_along_axis(
    labeled_img=intestine_img,
    meshes=meshes,
    exclude_labels=to_exclude,
    remove_empty=True,
    voxel_size=[0.325, 0.325, 0.25],
    number_slices=25
)

In [None]:
# Pool the per-slice neighbour counts of all cells, then report their mean and distribution.
nums = [num for cell_neighs in results[1].values() for num in cell_neighs]
count = len(nums)
print(sum(nums)/count)
print(np.unique(nums, return_counts=True))

In [None]:
# Per-cell slice specifications: (grid centers, slicing direction).
results[3]

## Test on subgroup of intestine cells

In [None]:
# Reload processed labels and per-cell meshes (same loading as the full-sample section,
# without the exclusion list). Meshes are ordered by the numeric label in the file name.
intestine_img = imread('../outputs/outputs_v2/output_intestine_sample2_b_curated_segmentation_relabel_seq_s_10_e_6_d_8/processed_labels.tif')
meshes_dir = '../outputs/outputs_v2/output_intestine_sample2_b_curated_segmentation_relabel_seq_s_10_e_6_d_8/cell_meshes/'
file_names = os.listdir(meshes_dir)
sorted_file_names = sorted(file_names, key=lambda x: int(x.split("_")[1].split(".")[0]))
meshes_files = [os.path.join(meshes_dir, mesh_file) for mesh_file in sorted_file_names]
meshes = []
for mesh_file in tqdm(meshes_files):
    meshes.append(tm.load_mesh(mesh_file))

In [None]:
# Viewer for the subgroup of cells.
viewer = napari.Viewer()

In [None]:
idxs = [71, 102, 109, 301, 127, 125, 137, 155, 118]
# Keep only the selected labels (everything else becomes 0).
cell_group = intestine_img * np.isin(intestine_img, idxs).astype(np.uint16)
# Assumes meshes[k] is the mesh of label k + 1 (files sorted by label, see meshes[300] for
# cell 301 below). Order the meshes by label so they line up with np.unique(cell_group),
# which compute_2D_statistics_along_axis pairs with the mesh list. Iterating
# range(len(meshes)) could never select the label equal to len(meshes).
meshes_group = [meshes[i - 1] for i in sorted(idxs)]
viewer.add_labels(cell_group)

In [None]:
# Statistics restricted to the subgroup. All slices are kept (remove_empty=False) so the
# grid specs of every slice can be visualised below.
results = compute_2D_statistics_along_axis(
    labeled_img=cell_group,
    meshes=meshes_group,
    exclude_labels=[],
    voxel_size=[0.325, 0.325, 0.25],
    remove_empty=False,
    number_slices=11
)

In [None]:
cell_mesh = meshes[300]  # presumably the mesh of label 301 (meshes[k] <-> label k + 1)
viewer.add_surface((cell_mesh.vertices, cell_mesh.faces))
grid_specs = results[3][301]
grid_shape = [201, 201]
grid = generate_2d_grid(grid_shape)
# NOTE(review): the in-plane basis is rebuilt from a fresh random vector. These grids share
# centers and normal with the ones used for the statistics, but may be rotated within
# their plane relative to them.
random_vector = np.random.rand(3)
normal_vector = random_vector - np.dot(random_vector, grid_specs[1]) * grid_specs[1]
normal_unit_vector = normal_vector / np.linalg.norm(normal_vector)
# Cross with the *unit* normal so the three columns are orthonormal (a valid rotation).
third_vector = np.cross(grid_specs[1], normal_unit_vector)
rot_matrix = np.column_stack(
    [normal_unit_vector, third_vector, grid_specs[1]]
)
rotations = [Rotation.from_matrix(rot_matrix)]

# cell_img = (cell_group == 301).astype(np.uint8)
# props = regionprops(cell_img)[0]
# cell_centroid = props.centroid
# cell_length = int(props.axis_major_length)
# principal_axis = grid_specs[1] * [0.325, 0.325, 0.25]
# viewer.add_points((np.array(([0,0,0], grid_specs[1])) + cell_centroid).astype(int), size=1)
# viewer.add_points(cell_centroid)
# viewer.add_vectors((np.array(([0,0,0], principal_axis))+ cell_centroid) * cell_length, edge_color='pink')

# Rotate and translate the grid to each slice center and show its points.
for center in grid_specs[0]:
    sampling_coordinates = place_sampling_grids(
        grid, center, rotations
    )
    viewer.add_points(sampling_coordinates.reshape((-1, 3)), size=0.5)

# Aboav-Weaire Law

### Two cells case

In [None]:
# Keep only cells 102 and 301 (every other label is multiplied by 0).
intestine_img = imread('../outputs/outputs_v2/output_intestine_sample2_b_curated_segmentation_relabel_seq_s_10_e_6_d_8/processed_labels.tif')
idxs = [102, 301,]
two_cells = intestine_img * np.isin(intestine_img, idxs).astype(np.uint16)

In [None]:
# Show the two-cell image.
viewer = napari.Viewer()
viewer.add_labels(two_cells)

In [None]:
# For each of the two cells (order of idxs): centroid, major-axis length and unit principal
# axis in image index units, then a set of points along each axis for plotting.
cell_centroids, cell_lengths = [], []
principal_vectors = []
for idx in idxs:
    cell_img = (two_cells == idx).astype(np.uint16)

    # Principal axes and centroids
    props = regionprops(cell_img)[0]
    cell_centroids.append(props.centroid)
    cell_lengths.append(int(props.axis_major_length))

    # Load the corresponding mesh
    cell_mesh = tm.load_mesh(
        f'../outputs/outputs_v2/output_intestine_sample2_b_curated_segmentation_relabel_seq_s_10_e_6_d_8/cell_meshes/cell_{idx}.stl'
    )

    # Compute the principal axis
    eigen_values, eigen_vectors = tm.inertia.principal_axis(cell_mesh.moment_inertia)

    # Get the index of the smallest eigen value
    smallest_eigen_value_idx = np.argmin(np.abs(eigen_values))
    greatest_eigen_value_idx = np.argmax(np.abs(eigen_values))  # unused below

    # Get the corresponding eigen vector 
    principal_axis = eigen_vectors[smallest_eigen_value_idx]
    # Mesh units -> image index units.
    # NOTE(review): z spacing 0.1625 here, 0.25 in the statistics cells — confirm which is right.
    principal_axis = principal_axis / np.array([0.325, 0.325, 0.1625])
    principal_vectors.append(principal_axis / np.linalg.norm(principal_axis))

#points along principal axes (for plotting): 40 points, offsets -100..95 step 5 from each centroid
principal_dirs = [
    np.asarray([vector * i + centroid for i in range(-100, 100, 5)]) 
    for vector, centroid in zip(principal_vectors, cell_centroids) 
]

In [None]:
# Plot principal axes (index 0 -> cell 102, index 1 -> cell 301, following idxs)
viewer.add_points(principal_dirs[0], name='Principal axis other cell')
viewer.add_points(principal_dirs[1], name='Principal axis main cell')

In [None]:
# Get plane passing by centroid of cell 301
# (the plane orthogonal to its principal axis, sampled as a 201 x 201 grid of points).

### 1. Rotation matrix (also for other cell)
# Gram-Schmidt on a random vector gives an in-plane axis; the orientation therefore
# depends on the RNG state when this cell runs.
rotations = []
for i in range(2):
    random_vector = np.random.rand(3)
    normal_vector = random_vector - np.dot(random_vector, principal_vectors[i]) * principal_vectors[i]
    normal_unit_vector = normal_vector / np.linalg.norm(normal_vector)
    third_vector = np.cross(principal_vectors[i], normal_unit_vector)
    rot_matrix = np.column_stack(
        (normal_unit_vector, third_vector, principal_vectors[i])
    )
    rotations.append([Rotation.from_matrix(rot_matrix)])

### 2. Create grid
grid_shape = [200 + 1] * 2
grid = generate_2d_grid(grid_shape)

### 3. Place grid
placed_grid = place_sampling_grids(
    grid, cell_centroids[1], rotations[1]
)

# Flatten to a (num_points, 3) list of plane points.
grid_coords = placed_grid.reshape(-1, 3)

In [None]:
# Plot plane through the centroid of cell 301
viewer.add_points(grid_coords, size=1, name='Plane by main cell centroid')

In [None]:
# Number of (plane point, axis point) pairs a brute-force closest-pair search compares.
grid_coords.shape[0] * principal_dirs[1].shape[0]

In [None]:
# Compute intersection between principal axis of the other cell (principal_dirs[0], cell 102)
# and the plane through the main cell's centroid. principal_dirs[1] is the main cell's own
# axis, which trivially meets that plane at the centroid.
#
# Both axis and plane are point sets, so take the pair of points with minimum distance.
# cdist evaluates all pairs at once instead of a Python loop over every pair. np.argmin
# returns the first minimum in row-major order, the same pair a strict "<" loop keeps.
from scipy.spatial.distance import cdist

pairwise_distances = cdist(grid_coords, principal_dirs[0])
grid_idx, axis_idx = np.unravel_index(np.argmin(pairwise_distances), pairwise_distances.shape)
min_distance = pairwise_distances[grid_idx, axis_idx]
closest_point_grid = grid_coords[grid_idx]
closest_point_axis = principal_dirs[0][axis_idx]

In [None]:
### We may want to speed up the algorithm by just taking 'close enough' points
# Accept the first (plane point, other-cell axis point) pair closer than 5 voxels.
# If no pair qualifies, both results stay None.
closest_point_grid = None
closest_point_axis = None
found = False

# Iterate through all combinations of points
for point1 in tqdm(grid_coords):
    for point2 in principal_dirs[0]:
        distance = euclidean(point1, point2)
        if distance < 5:
            closest_point_grid = point1
            closest_point_axis = point2
            found = True
            break
    if found: break

In [None]:
### Closest points are most likely in the middle of the grid.
### We may want to speed up the code by first looping from there.
# "Middle" is the middle of the flattened grid_coords array; stepping the flat index by ±i
# presumably walks along a grid row rather than radially.
# NOTE(review): indices midpoint-i for i < midpoint never reach index 0, and for an odd
# number of points the last index is never visited either. If no pair is under 5, both
# results stay None.

closest_point_grid = None
closest_point_axis = None
found = False

# Iterate through all combinations of points
midpoint = grid_coords.shape[0]//2
for i in tqdm(range(midpoint)):
    for j in range(principal_dirs[0].shape[0]):
        point1, point2 = grid_coords[midpoint+i, :], principal_dirs[0][j, :]
        distance = euclidean(point1, point2)
        if distance < 5:
            closest_point_grid = point1
            closest_point_axis = point2
            found = True
            break
        point1 = grid_coords[midpoint-i, :]
        distance = euclidean(point1, point2)
        if distance < 5:
            closest_point_grid = point1
            closest_point_axis = point2
            found = True
            break
    if found: break

In [None]:
def find_closest(points_cloud1, points_cloud2, lower_threshold):
    """
    Approximate the intersection of two point clouds by the midpoint of a close pair.

    Rows of `points_cloud1` are visited center-out: the middle index first, then one step
    above and one step below, and so on, until every row has been visited. For each
    visited row every point of `points_cloud2` is checked. The first pair closer than
    `lower_threshold` is accepted immediately. If no pair is that close, the closest pair
    seen in the whole search is used instead.

    Parameters:
    -----------
        points_cloud1: (np.ndarray[float])
            Array of shape (N, D).
        points_cloud2: (np.ndarray[float])
            Array of shape (M, D).
        lower_threshold: (float)
            Distance below which a pair is accepted without searching further.

    Returns:
    --------
        (np.ndarray[float])
            Midpoint of the selected pair, shape (D,).

    Raises:
    -------
        ValueError: if either point cloud is empty.
    """
    cloud1 = np.asarray(points_cloud1, dtype=float)
    cloud2 = np.asarray(points_cloud2, dtype=float)
    if len(cloud1) == 0 or len(cloud2) == 0:
        raise ValueError("find_closest needs two non-empty point clouds")

    n_points = cloud1.shape[0]
    midpoint = n_points // 2
    min_dist = np.inf
    closest_points = None
    # Offsets up to `midpoint` inclusive reach both index 0 and index n_points - 1.
    for offset in range(midpoint + 1):
        rows = [r for r in (midpoint + offset, midpoint - offset) if 0 <= r < n_points]
        # dists[k, j] = distance between cloud1[rows[k]] and cloud2[j]
        dists = np.linalg.norm(cloud1[rows][:, np.newaxis, :] - cloud2[np.newaxis, :, :], axis=-1)
        # Check order: for each cloud2 point, the upper row then the lower row.
        ordered = dists.T.ravel()
        hits = np.flatnonzero(ordered < lower_threshold)
        if hits.size:
            j, k = divmod(int(hits[0]), len(rows))
            return (cloud1[rows[k]] + cloud2[j]) / 2
        best = int(np.argmin(ordered))
        if ordered[best] < min_dist:
            min_dist = ordered[best]
            j, k = divmod(best, len(rows))
            closest_points = (cloud1[rows[k]], cloud2[j])

    # No pair under the threshold: fall back to the closest pair found.
    return (closest_points[0] + closest_points[1]) / 2

In [None]:
# Mark the approximate intersection of the other cell's axis (cell 102) with the plane of cell 301.
viewer.add_points(find_closest(grid_coords, principal_dirs[0], 2), size=10, name='intersection')

In [None]:
# Place other grid
# NOTE(review): closest_point_grid comes from whichever closest-point search cell ran last,
# not from the find_closest call in the previous cell.
other_centroid = closest_point_grid

# Plane orthogonal to the other cell's principal axis, centred on that point.
other_placed_grid = place_sampling_grids(
    grid, other_centroid, rotations[0]
)

other_grid_coords = other_placed_grid.reshape(-1, 3)

In [None]:
# Plot plane through the intersection point of the other cell (cell 102)
viewer.add_points(other_grid_coords, size=1, name='Plane by other cell centroid')

## Collect statistics for Aboav Law

Algorithm:

```
Iterate over the cells in the image to compute their principal axes, centroids and lengths.

for each cell in the image:
    Place slicing grids at different heights along the cell's principal axis.
    for each slicing grid: 
        Compute cell neighbors
        for each neighbor:
            Place the grid for that neighbor
            Compute its number of neighbors
```

```
Returns:
--------
aboav_law_dict: (Dict[int, Dict[int, List[float]]])
    {
        cell_id: {num_neighbors: List[avg_other_num_neighbors]}
    }
```

In [None]:
# Make the project's statistics_collection modules (e.g. StatsUtils) importable.
# Resolved against the working directory, which earlier cells already assume is the
# repository root (they import `src.tests...` and write to './src/tests/output'). This
# replaces a hard-coded absolute path on a shared drive.
sys.path.append(os.path.abspath(os.path.join('src', 'statistics_collection')))

In [None]:
from StatsUtils import _compute_2D_neighbors_along_direction, _compute_2D_area_along_direction

In [None]:
def _get_rotation(principal_vector):
    random_vector = np.random.rand(3)
    normal_vector = random_vector - np.dot(random_vector, principal_vector) * principal_vector
    normal_unit_vector = normal_vector / np.linalg.norm(normal_vector)
    third_vector = np.cross(principal_vector, normal_unit_vector)
    rot_matrix = np.column_stack(
        (normal_unit_vector, third_vector, principal_vector)
    )
    return [Rotation.from_matrix(rot_matrix)]

In [None]:
def _get_principal_axis(
    mesh: tm.base.Trimesh,
    scale: Iterable[float],
) -> np.ndarray[float]:
    """
    Compute the principal axis of a mesh and express it in labeled-image coordinates.

    The principal axis is the eigenvector of the mesh's inertia tensor whose eigenvalue
    has the smallest magnitude.

    Parameters:
    -----------
        mesh: (tm.base.Trimesh)
            A Trimesh object in an N-dimensional space (N=2,3).

        scale: (Iterable[float])
            Scale of the mesh coordinate system in microns, one value per axis.
            NOTE: in the statistics collection pipeline the labeled image lives in a
            coordinate system of scale (1, 1, 1), while meshes are generated with scale
            (voxel_size). The axis is therefore divided by `scale` to move it into the
            labeled-image coordinate system.

    Returns:
    --------
        normalized_principal_axis: (np.ndarray[float])
            Rescaled, unit-length principal axis, one component per axis.
    """
    components, axes = tm.inertia.principal_axis(mesh.moment_inertia)
    smallest = np.argmin(np.abs(components))
    axis_in_image_space = np.asarray(axes[smallest]) / np.asarray(scale)
    return axis_in_image_space / np.linalg.norm(axis_in_image_space)


In [None]:
def find_closest(points_cloud1, points_cloud2, lower_threshold):
    """
    Approximate the intersection of two point clouds by the midpoint of a close pair.

    Rows of `points_cloud1` are visited center-out: the middle index first, then one step
    above and one step below, and so on, until every row (including index 0 and the last
    index) has been visited. For each visited row every point of `points_cloud2` is
    checked. The first pair closer than `lower_threshold` is accepted immediately. If no
    pair is that close, the closest pair seen in the whole search is used instead.

    Parameters:
    -----------
        points_cloud1: (np.ndarray[float])
            Array of shape (N, D), e.g. the points of a placed sampling plane.
        points_cloud2: (np.ndarray[float])
            Array of shape (M, D), e.g. points sampled along a principal axis.
        lower_threshold: (float)
            Distance below which a pair is accepted without searching further.

    Returns:
    --------
        (np.ndarray[float])
            Midpoint of the selected pair, shape (D,).

    Raises:
    -------
        ValueError: if either point cloud is empty.
    """
    cloud1 = np.asarray(points_cloud1, dtype=float)
    cloud2 = np.asarray(points_cloud2, dtype=float)
    if len(cloud1) == 0 or len(cloud2) == 0:
        raise ValueError("find_closest needs two non-empty point clouds")

    n_points = cloud1.shape[0]
    midpoint = n_points // 2
    min_dist = np.inf
    closest_points = None
    # Offsets up to `midpoint` inclusive reach both index 0 and index n_points - 1
    # (the previous loop stopped one step short and could even skip every row).
    for offset in range(midpoint + 1):
        rows = [r for r in (midpoint + offset, midpoint - offset) if 0 <= r < n_points]
        # dists[k, j] = distance between cloud1[rows[k]] and cloud2[j], vectorised per row pair
        dists = np.linalg.norm(cloud1[rows][:, np.newaxis, :] - cloud2[np.newaxis, :, :], axis=-1)
        # Check order: for each cloud2 point, the upper row then the lower row.
        ordered = dists.T.ravel()
        hits = np.flatnonzero(ordered < lower_threshold)
        if hits.size:
            j, k = divmod(int(hits[0]), len(rows))
            return (cloud1[rows[k]] + cloud2[j]) / 2
        best = int(np.argmin(ordered))
        if ordered[best] < min_dist:
            min_dist = ordered[best]
            j, k = divmod(best, len(rows))
            closest_points = (cloud1[rows[k]], cloud2[j])

    # No pair under the threshold: fall back to the closest pair found.
    return (closest_points[0] + closest_points[1]) / 2

In [None]:
def _compute_neighbors_of_neighbors_along_direction(
        labeled_img: np.ndarray[int],
        neighbors: Iterable[int],
        grid_coords: np.ndarray[float],
        principal_axes: Dict[int, np.ndarray[float]],
        principal_axis_pts: Dict[int, np.ndarray[float]],
        grid_to_place: np.ndarray[float],
        show_logs: Optional[bool] = False,
        intersection_threshold: float = 20,
) -> Optional[List[int]]:
    """
    Count the 2D neighbors of each neighbor of a cell, on slices orthogonal to the
    neighbor's own principal axis.

    For every neighbor:
    1. The slice center is where its principal axis meets the main cell's sampling plane,
       approximated with `find_closest` between the plane points and points sampled
       along the neighbor's axis.
    2. A grid orthogonal to the neighbor's principal axis is placed there and sampled
       with nearest-neighbour interpolation.
    3. The neighbor's neighbors are counted on that slice.

    Parameters:
    -----------
        labeled_img: (np.ndarray[int])
            3D labeled image to sample.
        neighbors: (Iterable[int])
            Labels of the main cell's neighbors on the current slice.
        grid_coords: (np.ndarray[float])
            Points of the main cell's placed sampling plane.
        principal_axes: (Dict[int, np.ndarray[float]])
            Principal axis direction per label.
        principal_axis_pts: (Dict[int, np.ndarray[float]])
            Points sampled along each label's principal axis.
        grid_to_place: (np.ndarray[float])
            2D sampling grid reused for every neighbor.
        show_logs: (Optional[bool], default=False)
            Print progress messages.
        intersection_threshold: (float, default=20)
            Distance under which `find_closest` accepts a plane/axis point pair as the
            intersection (previously hard-coded to 20).

    Returns:
    --------
        (Optional[List[int]])
            Number of neighbors of each neighbor, in iteration order. Returns None as soon
            as one neighbor has an incomplete neighborhood on its slice.
    """
    # iterate over neighbors from a slice to compute neighbors of neighbors
    neigh_num_neighbors = []
    for neighbor in neighbors:
        if show_logs:
            print(f"            Computing neigbors of neighbor {neighbor}")
        # get points on principal axis of neighboring cell
        neigh_principal_pts = principal_axis_pts[neighbor]
        neigh_principal_vector = principal_axes[neighbor]
        # get intersection between grid of main cell and points of neighbor principal axis
        neigh_center = find_closest(grid_coords, neigh_principal_pts, intersection_threshold)
        # place grid and sample slice for neighbor
        neigh_rot = _get_rotation(neigh_principal_vector)
        neigh_placed_grid = place_sampling_grids(grid_to_place, neigh_center, neigh_rot)
        neigh_sampled_slice = sample_volume_at_coordinates(
            labeled_img,
            neigh_placed_grid,
            interpolation_order=0,
        )
        # The third argument is the background threshold; 0 presumably rejects any
        # background contact (as in the local definition of this helper).
        neigh_neighbors = _compute_2D_neighbors_along_direction(neigh_sampled_slice, neighbor, 0)
        # if any neighbor of main cell doesn't have complete neighborhood go to next slice
        if neigh_neighbors == [-1]:
            if show_logs:
                print("                Incomplete neighborhood, skipping current slice...")
            return None
        if show_logs:
            print(f"                Neighbor {neighbor} has {len(neigh_neighbors)} neighbors")
        neigh_num_neighbors.append(len(neigh_neighbors))

    return neigh_num_neighbors

In [None]:
def _get_centroid_and_length(
        binary_img: np.ndarray[int]
) -> Tuple[np.ndarray[float], float]:
    """
    Centroid of the object in a 3D binary image and its extent along the last axis.

    Parameters:
    -----------
        binary_img: (np.ndarray[int])
            A 3D binary image; only the first region reported by `regionprops` is used.

    Returns:
    --------
        centroid: (tuple of float)
            Centroid coordinates of the object, as given by `regionprops`.

        length: (int)
            Bounding-box extent of the object along the last image axis (max - min).
            NOTE: this is not the major-axis length.
    """
    region = regionprops(binary_img)[0]
    # 3D bbox layout: (min_0, min_1, min_2, max_0, max_1, max_2); the unpacking requires 3D input.
    _, _, first_slice, _, _, end_slice = region.bbox
    return region.centroid, end_slice - first_slice

In [None]:
def _get_slices_along_direction(
    labeled_img: np.ndarray[int],
    slicing_dir: Iterable[float],
    centroid: Iterable[float],
    height: int,    
    grid_to_place: np.ndarray[int],
    num_slices: Optional[int] = 10,
) -> Tuple[
    List[np.ndarray[int]],
    List[np.ndarray[float]],
    Tuple[List[List[float]], List[float]],
]:
    """
    Extract 2D slices along a given direction from a 3D labeled image.

    The slice centers are `num_slices` points evenly spaced on the segment
    centroid + t * slicing_dir, with t in [-height, height]. At each center a copy of
    `grid_to_place` is rotated by `_get_rotation(slicing_dir)` and translated there.
    `labeled_img` is then sampled on that grid.

    Parameters:
    -----------
        labeled_img: (np.ndarray[int])
            A 3D labeled image where the background has a label of 0 and cells are labeled with 
            consecutive integers starting from 1.
        
        slicing_dir: (Iterable[float])
            A triplet describing a unit vector in the labeled_img 3D coordinate system.

        centroid: (Iterable[float])
            A triplet associated to the coordinates of the centroid of the object at the center
            of the slices.
        
        height: (int)
            Offset (along slicing_dir) of the first and last slice centers from the centroid,
            i.e. the slices span a total extent of 2*height.

        grid_to_place: (np.ndarray[int])
            The sampling grid (e.g. built with `generate_2d_grid`) that is rotated and
            translated to every slice center.
        
        num_slices: (Optional[int], default=10)
            The number of slices to extract from the labeled image.

    Returns:
    --------
        labeled_slices: (List[np.ndarray[int]])
            The label values sampled (interpolation_order=0) on each placed grid.

        grid_coords: (List[np.ndarray[float]])
            For each slice, the sampling coordinates flattened to shape (M, 3).
        
        grid_specs: (Tuple[List[List[float]], List[float]])
            A tuple consisting of the list of grid-center coordinates and the slicing direction
            (which is the same for all slices). These values identify the grids that have been
            used to sample labeled_slices from labeled_img.
    """
    # Centers of the sampling grids, evenly spaced along slicing_dir around the centroid
    slicing_dir = np.asarray(slicing_dir)
    grid_centers = [
        slicing_dir * offset + centroid
        for offset in np.linspace(-height, height, num_slices)
    ]

    # All slices share the same orientation, so the rotation is computed only once
    rot = _get_rotation(slicing_dir)

    # Identifiers of the different grids (centers and direction)
    grid_specs = (
        [list(center) for center in grid_centers],
        list(slicing_dir)
    )

    labeled_slices = []
    grid_coords = []
    for center in grid_centers:
        # Rotate and translate the grid to the current center
        sampling_coordinates = place_sampling_grids(grid_to_place, center, rot)
        grid_coords.append(sampling_coordinates.reshape(-1, 3))

        # Nearest-neighbour sampling keeps label values intact
        labeled_slices.append(
            sample_volume_at_coordinates(
                labeled_img,
                sampling_coordinates,
                interpolation_order=0,
            )
        )

    return labeled_slices, grid_coords, grid_specs

In [None]:
def compute_2D_statistics_along_axes(
        labeled_img: np.ndarray[int],
        cell_mesh_dict: Dict[int, tm.base.Trimesh],
        exclude_labels: Iterable[int],
        voxel_size: Iterable[float],
        number_slices: int = 10, 
        slice_size: int = 200,
        remove_empty: Optional[bool] = True
) -> Tuple[Dict[int, List[List[int]]], 
           Dict[int, List[float]], 
           Dict[int, Dict[int, List[int]]], 
           Dict[int, Tuple[List[List[float]], List[float]]]]:
    """
    Compute per-cell 2D statistics on slices taken along each cell's principal axis.

    For every non-background label, the principal axis is computed from the cell mesh.
    The centroid and half-length are taken from the voxel image. `number_slices`
    slices are then sampled along that axis. On each slice the function computes the
    cell area, the list of neighbors, and (for slices with a usable neighborhood) the
    neighbor counts of the neighbors.

    Parameters:
    -----------
        labeled_img: (np.ndarray[int])
            A 3D labeled image; label 0 is treated as background and ignored.

        cell_mesh_dict: (Dict[int, tm.base.Trimesh])
            Mapping label -> mesh, used to compute the principal axis. It must contain
            every label of labeled_img that is not in exclude_labels.

        exclude_labels: (Iterable[int])
            Labels for which no statistics are computed. Slices whose neighbors include
            any of these labels are skipped for the neighbors-of-neighbors computation.

        voxel_size: (Iterable[float])
            Voxel size of labeled_img. It is passed as `scale` to `_get_principal_axis`,
            and its first two entries are used as the pixel size for area computation.

        number_slices: (int, default=10)
            Number of slices sampled along each principal axis.

        slice_size: (int, default=200)
            Side of the square sampling grid. If it exceeds any image dimension, it is
            replaced by the smallest image dimension.

        remove_empty: (Optional[bool], default=True)
            If True, slices whose neighbor computation returned [-1] (incomplete
            neighborhood) are dropped from the neighbors, areas and slice specs.

    Returns:
    --------
        neighbors_dict: (Dict[int, List[List[int]]])
            label -> per-slice list of neighbor labels.

        areas_dict: (Dict[int, List[float]])
            label -> per-slice cell area.

        neighbors_of_neighbors_dict: (Dict[int, Dict[int, List[int]]])
            label -> {number of neighbors in a slice: list of the per-slice results of
            `_compute_neighbors_of_neighbors_along_direction`}.

        slices_dict: (Dict[int, Tuple[List[List[float]], List[float]]])
            label -> (grid centers, slicing direction) identifying the sampled slices.

        Excluded labels are mapped to [], [], {} and () respectively.
    """
    if np.any(slice_size > np.asarray(labeled_img.shape)):
        slice_size = np.min(labeled_img.shape)

    # Materialize the exclusion list once. A set gives O(1) membership checks and
    # does not silently consume a one-shot iterator (e.g. a generator) on first use.
    excluded = set(exclude_labels)

    print('Computing cell 2D statistics along apical-basal axis...')

    # Cell labels present in the image. Background (0) is removed explicitly instead
    # of assuming it is always the first unique value.
    label_ids = np.unique(labeled_img)
    label_ids = label_ids[label_ids != 0]

    # Compute principal axes, centroids and half-lengths for all the cells
    cell_centroids, cell_lengths, cell_principal_axes = {}, {}, {}
    cell_principal_vectors = {}  # array of points on the direction of the principal axes
    for label_id in tqdm(label_ids, desc='Computing principal axes and centroids'):
        if label_id in excluded:
            cell_principal_axes[label_id] = None
            cell_centroids[label_id] = None
            cell_lengths[label_id] = None
        else:
            # Compute principal axis, axis length and centroid coordinates 
            cell_mesh = cell_mesh_dict[label_id]
            principal_axis = _get_principal_axis(
                mesh=cell_mesh,
                scale=voxel_size
            )
            cell_principal_axes[label_id] = principal_axis

            binary_img = (labeled_img == label_id).astype(np.uint8)
            cell_centroid, cell_length = _get_centroid_and_length(binary_img)
            # Slices are spread over [-half_length, +half_length] around the centroid
            cell_length = int(cell_length // 2)
            cell_centroids[label_id] = cell_centroid
            cell_lengths[label_id] = cell_length
            cell_principal_vectors[label_id] = np.asarray([
                principal_axis * i + cell_centroid 
                for i in np.linspace(-cell_length, cell_length, number_slices)
            ])

    # Generate a grid of the desired size for sampling from the image
    grid_shape = [slice_size + 1] * 2
    grid = generate_2d_grid(grid_shape)

    neighbors_dict = {}
    areas_dict = {}
    neighbors_of_neighbors_dict = {}
    slices_dict = {}
    for label_id in tqdm(label_ids, desc='Computing 2D statistics'):
        if label_id in excluded:
            neighbors_dict[label_id] = []
            areas_dict[label_id] = []
            neighbors_of_neighbors_dict[label_id] = {}
            slices_dict[label_id] = ()
            continue

        # Get slices along principal axis direction
        labeled_slices, grid_coords, slices_specs = _get_slices_along_direction(
            labeled_img=labeled_img,
            slicing_dir=cell_principal_axes[label_id],
            centroid=cell_centroids[label_id],
            height=cell_lengths[label_id],
            grid_to_place=grid,
            num_slices=number_slices
        )

        # Iterate across slices to compute the statistics
        cell_areas = []
        cell_neighbors = []
        cell_neighbors_of_neighbors = defaultdict(list)
        for i, labeled_slice in enumerate(labeled_slices):
            # NOTE(review): voxel_size[:2] is used as the in-plane pixel size even though
            # the slice is not necessarily axis-aligned; confirm this approximation.
            area_in_slice = _compute_2D_area_along_direction(
                labeled_slice=labeled_slice,
                cell_label=label_id,
                pixel_size=voxel_size[:2]
            )
            cell_areas.append(area_in_slice)

            neighbors_in_slice = _compute_2D_neighbors_along_direction(
                labeled_slice=labeled_slice,
                cell_label=label_id
            )
            cell_neighbors.append(neighbors_in_slice)

            # Skip incomplete neighborhoods and neighborhoods touching excluded cells
            # (excluded cells have no principal-axis points to place a grid on)
            if (neighbors_in_slice == [-1]) or any(
                neighbor in excluded for neighbor in neighbors_in_slice
            ):
                continue
            neighbors_of_neighbors_in_slice = _compute_neighbors_of_neighbors_along_direction(
                labeled_img=labeled_img,
                neighbors=neighbors_in_slice,
                grid_coords=grid_coords[i],
                principal_axis_pts=cell_principal_vectors,
                principal_axes=cell_principal_axes,
                grid_to_place=grid,
                show_logs=False
            )
            # A falsy result means at least one neighbor had an incomplete neighborhood
            if neighbors_of_neighbors_in_slice:
                cell_neighbors_of_neighbors[len(neighbors_in_slice)].append(
                    neighbors_of_neighbors_in_slice
                )

        if remove_empty:
            # Post-process results to remove slices with an incomplete neighborhood
            to_remove = [neighs == [-1] for neighs in cell_neighbors]
            cell_neighbors = [neighs for neighs, flag in zip(cell_neighbors, to_remove) if not flag]
            cell_areas = [area for area, flag in zip(cell_areas, to_remove) if not flag]
            new_slices_specs = [item for item, flag in zip(slices_specs[0], to_remove) if not flag]
            slices_specs = (new_slices_specs, slices_specs[1])

        neighbors_dict[label_id] = cell_neighbors
        areas_dict[label_id] = cell_areas
        neighbors_of_neighbors_dict[label_id] = dict(cell_neighbors_of_neighbors)
        slices_dict[label_id] = slices_specs

    return neighbors_dict, areas_dict, neighbors_of_neighbors_dict, slices_dict

### Test on a small sample of a few cells

In [None]:
# Restrict the analysis to a hand-picked group of cells so that computation stays quick
labels = '185 173 128 135 141 147 150 159 169 171 175 180 184 186 187 189 195 208 210 211 213 214 228 237 243 246 270 271 318'
idxs = list(map(int, labels.split()))
intestine_img = imread('../outputs/outputs_v2/output_intestine_sample2_b_curated_segmentation_relabel_seq_s_10_e_6_d_8/processed_labels.tif')
# Keep the original label of selected cells and set every other voxel to 0
cell_group = np.isin(intestine_img, idxs).astype(np.uint16) * intestine_img

In [None]:
# Load the mesh of each selected cell. Mesh files follow the `cell_<label>.stl` naming
# used elsewhere in this notebook. Matching the exact file name avoids the previous
# substring match, which could pick e.g. `cell_1850.stl` for label 185.
meshes_dir = '../outputs/outputs_v2/output_intestine_sample2_b_curated_segmentation_relabel_seq_s_10_e_6_d_8/cell_meshes/'
meshes = {}
for idx in tqdm(idxs):
    meshes[idx] = tm.load_mesh(os.path.join(meshes_dir, f'cell_{idx}.stl'))

In [None]:
# Open a napari viewer and add the selected group of cells as a labels layer for visual inspection
viewer = napari.Viewer()
viewer.add_labels(cell_group)

In [None]:
# # Collect principal axes, centroids and cell length throughout the image
# cell_centroids, cell_lengths, principal_vectors = {}, {}, {}
# for idx in tqdm(idxs):
#     cell_img = (cell_group == idx).astype(np.uint8)

#     # Principal axes and centroids
#     props = regionprops(cell_img)[0]
#     cell_centroids[idx] = props.centroid
#     cell_lengths[idx] = int(props.axis_major_length)

#     # Load the corresponding mesh
#     cell_mesh = tm.load_mesh(
#         f'../outputs/outputs_v2/output_intestine_sample2_b_curated_segmentation_relabel_seq_s_10_e_6_d_8/cell_meshes/cell_{idx}.stl'
#     )

#     # Compute the principal axis
#     eigen_values, eigen_vectors = tm.inertia.principal_axis(cell_mesh.moment_inertia)

#     # Get the index of the smallest eigen value
#     smallest_eigen_value_idx = np.argmin(np.abs(eigen_values))
#     greatest_eigen_value_idx = np.argmax(np.abs(eigen_values))

#     # Get the corresponding eigen vector 
#     principal_axis = eigen_vectors[smallest_eigen_value_idx]
#     principal_axis = principal_axis / np.array([0.325, 0.325, 0.1625])
#     principal_vectors[idx] = principal_axis / np.linalg.norm(principal_axis)

# # Points along the principal axes (used for centering grids)
# principal_axis_pts = {
#     id: np.asarray([vector * i + centroid for i in range(-100, 100, 20)]) 
#     for id, vector, centroid in zip(principal_vectors.keys(), principal_vectors.values(), cell_centroids.values()) 
# }

In [None]:
# from collections import defaultdict
# from time import time

# aboav_law_dict = {}
# for idx in idxs:
#     start = time()
#     print("-----------------------------------------")
#     print(f"Analyzing cell {idx}:")
#     # initialize sub dict for current cell
#     cell_dict = defaultdict(list)

#     # get values for this cell
#     principal_vector = principal_vectors[idx]
#     grid_centers = principal_axis_pts[idx]
#     cell_length = cell_lengths[idx]

#     # rotation matrix
#     rot = _get_rotation(principal_vector)

#     # create grid in the origin
#     grid_shape = [200 + 1] * 2
#     grid = generate_2d_grid(grid_shape)

#     # iterate over different points on the principal axis and place grids
#     slc_counter = 1
#     for grid_center in grid_centers:
#         print(f"    Computing stats for slice {slc_counter}/{len(grid_centers)}")
#         slc_counter += 1
#         # sample a slice
#         placed_grid = place_sampling_grids(grid, grid_center, rot)
#         grid_coords = placed_grid.reshape(-1, 3)
#         sampled_slice = sample_volume_at_coordinates(
#             cell_group,
#             placed_grid,
#             interpolation_order=0,
#         )

#         # compute neighbors
#         neighbors = _compute_2D_neighbors_along_direction(sampled_slice, idx)
#         if neighbors == [-1]:
#             print("        Incomplete neighborhood, skipping current slice...")
#             continue

#         print(f"        Cell neighbors: {neighbors}")
        
#         # iterate over neighbors to compute neighbors of neighbors
#         neigh_num_neighbors = []
#         for neighbor in neighbors:
#             print(f"            Computing neigbors of neighbor {neighbor}")
#             # get points on principal axis of neighboring cell
#             neigh_principal_pts = principal_axis_pts[neighbor]
#             neigh_principal_vector = principal_vectors[neighbor]
#             # get intersection between grid of main cell and points of neighbor principal axis
#             neigh_center = find_closest(grid_coords, neigh_principal_pts, 20)
#             # place grid and sample slice for neighbor
#             neigh_rot = _get_rotation(neigh_principal_vector)
#             neigh_placed_grid = place_sampling_grids(grid, neigh_center, neigh_rot)
#             neigh_sampled_slice = sample_volume_at_coordinates(
#                 cell_group,
#                 neigh_placed_grid,
#                 interpolation_order=0,
#             )
#             # compute number of neighbors for neighbor
#             neigh_neighbors = _compute_2D_neighbors_along_direction(neigh_sampled_slice, neighbor)
#             # if any neighbor of main cell doesn't have complete neighborhood go to next slice
#             if neigh_neighbors == [-1]:
#                 print("                Incomplete neighborhood, skipping current slice...")
#                 break
#             else:
#                 print(f"                Neighbor {neighbor} has {len(neigh_neighbors)} neighbors")
#                 neigh_num_neighbors.append(len(neigh_neighbors))

#         if len(neigh_num_neighbors) == len(neighbors):
#             cell_dict[len(neighbors)] = cell_dict[len(neighbors)] + neigh_num_neighbors     

#     print(f'Time elapsed: {time() - start}s.')

#     aboav_law_dict[idx] = cell_dict

In [None]:
# Run the 2D statistics on the small cell group, without excluding any cell.
# NOTE(review): the principal-axis exploration at the top of the notebook scales by
# z = 0.1625 while this call uses z = 0.25; confirm the correct z spacing.
results = compute_2D_statistics_along_axes(
    cell_group,
    meshes,
    remove_empty=True,
    voxel_size=[0.325, 0.325, 0.25],
    exclude_labels=[],
)

In [None]:
# Inspect cell 214: neighbors-of-neighbors dict (results[2]) and per-slice neighbors (results[0])
results[2][214], results[0][214]

### Test on the whole intestine sample

In [None]:
# NOTE(review): the label image is read from `outputs_v2`, while the meshes (and the
# excluded-cells list) come from `outputs_v3`. Confirm both versions share the same labels.
intestine_img = imread('../outputs/outputs_v2/output_intestine_sample2_b_curated_segmentation_relabel_seq_s_10_e_6_d_8/processed_labels.tif')

# Load one mesh per cell label. Background (0) is dropped explicitly rather than by
# position, and the name `label_ids` avoids reusing `labels` (a string in the cell above).
label_ids = np.unique(intestine_img)
label_ids = label_ids[label_ids != 0]
meshes_dict = {}
for label in tqdm(label_ids):
    meshes_dict[label] = tm.load_mesh(
        f'../outputs/outputs_v3/output_intestine_sample2_b_curated_segmentation_relabel_seq_s_10_e_6_d_8/cell_meshes/cell_{label}.stl'
    )

In [None]:
# Load excluded cells
with open('/nas/groups/iber/Users/Federico_Carrara/Statistics_Collection/outputs/outputs_v3/output_intestine_sample2_b_curated_segmentation_relabel_seq_s_10_e_6_d_8/cut_cells_labels.txt', 'r') as file:
    exclude_cells = file.readlines()

exclude_cells = [int(float(val.strip())) for val in exclude_cells]

In [None]:
# Run the 2D statistics on the whole sample, skipping the cells listed in `exclude_cells`
results = compute_2D_statistics_along_axes(
    intestine_img,
    meshes_dict,
    remove_empty=True,
    voxel_size=[0.325, 0.325, 0.25],
    exclude_labels=exclude_cells,
)

In [None]:
# Display the neighbors-of-neighbors dictionary (third element returned by compute_2D_statistics_along_axes)
results[2]