## Test OuterShell

In [None]:
import os
import napari
import numpy as np
from scipy.spatial import Delaunay
from scipy.interpolate import LinearNDInterpolator, NearestNDInterpolator
import trimesh
import math
import open3d as o3d
from copy import deepcopy
from trimesh.creation import icosphere
from typing import List

In [None]:
### UTIL FUNCTION TO MAKE TEST SAMPLE OF SPHERICAL MESHES ###
def create_aligned_spheres(
        n: int,
        radius: float
) -> List[trimesh.Trimesh]:
    """
    Build `n` icosphere meshes laid out in a row along the x-axis.

    The first sphere is left where `icosphere` puts it; every following sphere
    is translated one diameter (2 * radius) further along x than the previous
    one, so all centers stay aligned on a line parallel to the x-axis.

    Parameters:
    -----------
    n: (int)
        How many spheres to generate.
    radius: (float)
        Radius used for every sphere.

    Returns:
    --------
    spheres: (List[trimesh.Trimesh])
        The sphere meshes, ordered by increasing x offset.
    """
    offset = np.zeros(3)
    sphere_meshes = []
    for _ in range(n):
        sphere = icosphere(radius=radius)
        sphere.vertices += offset
        sphere_meshes.append(sphere)
        # Shift the next sphere by one diameter along x.
        offset[0] += 2 * radius
    return sphere_meshes

#### Get outer shell from single meshes

In [None]:
from OuterShell import ExtendedTrimesh, OuterShell

In [None]:
# ### SPHERES ###

# # Initialize synthetic input
# spheres = create_sample_spheres(3, [1, 1, 1])
# neighbors_lst = [[1], [0, 2], [1]]

# # Calculate the length of each edge
# edges = spheres[0].edges
# edge_lengths = spheres[0].vertices[edges[:, 0]] - spheres[0].vertices[edges[:, 1]]
# edge_lengths = trimesh.util.row_norm(edge_lengths)

# # Get the shortest edge length
# shortest_edge_length = min(edge_lengths)

# # Transform in ExtendedTrimesh format
# extended_meshes = []
# for i, sphere in enumerate(spheres):
#     extended_meshes.append(ExtendedTrimesh(neighbors=neighbors_lst[i], vertices=sphere.vertices, faces=sphere.faces))

In [None]:
### CELL MESHES ###
root_dir = r"N:\Users\Federico_Carrara\Meshes_for_Simulation\examples\cell_clump_intestine\cell_clumps\clean_clump_16_cells\clean_meshes"
# root_dir = "/nas/groups/iber/Users/Federico_Carrara/Meshes_for_Simulation/examples/cell_clump_intestine/cell_clumps/clean_clump_16_cells/clean_meshes"
mesh1 = trimesh.load_mesh(os.path.join(root_dir, "cell_193.stl"))
mesh2 = trimesh.load_mesh(os.path.join(root_dir, "cell_228.stl"))
mesh3 = trimesh.load_mesh(os.path.join(root_dir, "cell_171.stl"))

mesh_lst = [mesh1, mesh2, mesh3]
# Adjacency: neighbors_lst[i] holds the indices (into mesh_lst) of the cells
# neighbouring cell i.
neighbors_lst = [[1, 2], [0, 2], [0, 1]]

# # Calculate the length of each edge
# edges = mesh1.edges
# edge_lengths = mesh1.vertices[edges[:, 0]] - mesh1.vertices[edges[:, 1]]
# edge_lengths = trimesh.util.row_norm(edge_lengths)

# # Get the shortest edge length
# shortest_edge_length = min(edge_lengths)
# print(shortest_edge_length)

# Transform in ExtendedTrimesh format.
# zip pairs each mesh with its own neighbour list. The lengths are checked first
# because zip would otherwise silently drop unmatched entries.
if len(mesh_lst) != len(neighbors_lst):
    raise ValueError(
        f"Got {len(mesh_lst)} meshes but {len(neighbors_lst)} neighbour lists"
    )
extended_meshes = [
    ExtendedTrimesh(neighbors=neighbors, vertices=mesh.vertices, faces=mesh.faces)
    for mesh, neighbors in zip(mesh_lst, neighbors_lst)
]

In [None]:
# Wrap the per-cell ExtendedTrimesh objects and their adjacency lists in an OuterShell.
outer_shell = OuterShell(meshes=extended_meshes, neighbors_lst=neighbors_lst)

###### Plots for presentation

In [None]:
# Run the shell-generation pipeline with displace_points=True. The STL exported
# in the next cell is named "..._wdispl". What the displacement does is defined
# inside OuterShell.
outer_shell.generate_outer_shell(displace_points=True)

In [None]:
# Export outer_shell.mesh to STL.
# NOTE(review): this is a hard-coded N:\ network-drive path, so it only works on that machine.
outer_shell.mesh.export(r"N:\Users\Federico_Carrara\master_thesis_docs\IMGS\for_final_pres\outer_shell_example_wdispl.stl")

In [None]:
# Keep handles on the shell's point coordinates and their normals.
pc, pc_normals = outer_shell.points, outer_shell.point_normals

In [None]:
# Open an empty napari viewer; nothing is added to it in this cell.
viewer = napari.Viewer()

#### Test 'get_shell_point_cloud' method

In [None]:
# Extract the shell point cloud with dist_threshold=10. The units are
# presumably those of the mesh coordinates — TODO confirm.
outer_shell.get_shell_point_cloud(dist_threshold=10)

In [None]:
# Check that k-closest points are correctly extracted
# k_closest_idxs_1 = np.concatenate([
#     outer_shell._meshes[0].k_closest_dict[1],
#     outer_shell._meshes[0].k_closest_dict[2],
# ])
# k_closest_idxs_2 = np.concatenate([
#     outer_shell._meshes[1].k_closest_dict[0],
#     outer_shell._meshes[1].k_closest_dict[2],
# ])
# k_closest_idxs_3 = np.concatenate([
#     outer_shell._meshes[2].k_closest_dict[0],
#     outer_shell._meshes[2].k_closest_dict[1],
# ])

# k_closest_pts_1 = outer_shell._meshes[0].points[k_closest_idxs_1]
# k_closest_pts_2 = outer_shell._meshes[1].points[k_closest_idxs_2]
# k_closest_pts_3 = outer_shell._meshes[2].points[k_closest_idxs_3]

# k_closest_idxs_1_2 = outer_shell._meshes[0].k_closest_dict[1]
# k_closest_idxs_1_3 = outer_shell._meshes[0].k_closest_dict[2]
# k_closest_idxs_2_1 = outer_shell._meshes[1].k_closest_dict[0]
# k_closest_idxs_2_3 = outer_shell._meshes[1].k_closest_dict[2]
# k_closest_idxs_3_1 = outer_shell._meshes[2].k_closest_dict[0]
# k_closest_idxs_3_2 = outer_shell._meshes[2].k_closest_dict[1]

# k_closest_pts_1_2 = outer_shell._meshes[0].points[k_closest_idxs_1_2]
# k_closest_pts_1_3 = outer_shell._meshes[0].points[k_closest_idxs_1_3]
# k_closest_pts_2_1 = outer_shell._meshes[1].points[k_closest_idxs_2_1]
# k_closest_pts_2_3 = outer_shell._meshes[1].points[k_closest_idxs_2_3]
# k_closest_pts_3_1 = outer_shell._meshes[2].points[k_closest_idxs_3_1]
# k_closest_pts_3_2 = outer_shell._meshes[2].points[k_closest_idxs_3_2]

# viewer = napari.Viewer()
# viewer.add_points(outer_shell.points, size=0.1)
# viewer.add_points(k_closest_pts_1_2, size=0.1, face_color="green")
# viewer.add_points(k_closest_pts_1_3, size=0.1, face_color="red")
# viewer.add_points(k_closest_pts_2_1, size=0.1, face_color="orange")
# viewer.add_points(k_closest_pts_2_3, size=0.1, face_color="purple")
# viewer.add_points(k_closest_pts_3_1, size=0.1, face_color="blue")
# viewer.add_points(k_closest_pts_3_2, size=0.1, face_color="brown")

#### Test Interpolation 

##### 1. Fit a global model, evaluate it on the gaps
Pros:
- A global model can capture features shared by the whole point cloud and extend them into the gaps.

Cons:
- The parameters may depend on the particular sample/shape, which makes the approach hard to generalize.
- It is not trivial to identify the gaps and to place appropriate sampling grids over them.

In [None]:
# Train the chosen model on existing data: z is predicted from (x, y) by
# nearest-neighbour lookup over all current shell points.
x, y, z = outer_shell.points[:, 0], outer_shell.points[:, 1], outer_shell.points[:, 2]
model = NearestNDInterpolator(np.column_stack([x, y]), z)

### Get grid of x, y values at the gaps
# 0. Grid step: mean point spacing over all meshes, divided by sqrt(2)
#    (presumably so that a grid cell's diagonal equals that spacing — TODO confirm).
grid_step = np.mean([mesh.mean_point_distance for mesh in outer_shell._meshes]) / math.sqrt(2)

# 1. Collect every unordered pair of neighbouring cells once (sorted tuple as key).
neighbor_pairs = set()
for idx, neighbors in enumerate(outer_shell._neighbors_lst):
    for neighbor in neighbors:
        neighbor_pairs.add(tuple(sorted((idx, neighbor))))

# 2. For each pair. The pairs are sorted so the stacked output has a reproducible order.
new_shell_points = []
all_closest_points = []
pred_grids = []
for idx_1, idx_2 in sorted(neighbor_pairs):
    # 2.a. Get closest points for each cell in the pair
    mesh_1, mesh_2 = outer_shell._meshes[idx_1], outer_shell._meshes[idx_2]
    closest_point_idxs_1 = mesh_1.k_closest_dict[idx_2]
    closest_point_idxs_2 = mesh_2.k_closest_dict[idx_1]
    closest_points = np.concatenate(
        [mesh_1.points[closest_point_idxs_1], mesh_2.points[closest_point_idxs_2]]
    )
    all_closest_points.append(closest_points)

    # 2.b. Build a regular (x, y) grid spanning the extrema of the closest points.
    #      np.linspace(lo, hi, num) has num - 1 intervals, so ceil(span / step) + 1
    #      samples keep the spacing <= grid_step. Truncating span / step instead
    #      would make the grid coarser than intended, and empty when span < grid_step.
    min_x, max_x = np.min(closest_points[:, 0]), np.max(closest_points[:, 0])
    min_y, max_y = np.min(closest_points[:, 1]), np.max(closest_points[:, 1])
    num_x = math.ceil((max_x - min_x) / grid_step) + 1
    num_y = math.ceil((max_y - min_y) / grid_step) + 1
    X, Y = np.meshgrid(
        np.linspace(min_x, max_x, num_x),
        np.linspace(min_y, max_y, num_y),
    )
    X, Y = X.ravel(), Y.ravel()
    # Flat z = 0 copy of the grid, only used for visualisation.
    pred_grids.append(np.column_stack([X, Y, np.zeros_like(X)]))

    # 3. Predict z on the newly created grid with the global model.
    Z_pred = model(X, Y)
    new_shell_points.append(np.column_stack([X, Y, Z_pred.ravel()]))

new_shell_points = np.vstack(new_shell_points)
all_closest_points = np.vstack(all_closest_points)
pred_grids = np.vstack(pred_grids)

In [None]:
# Overlay of the predicted gap points (green), the original shell points
# (default colour), the closest points bounding each grid (red) and the flat
# z = 0 sampling grids (blue).
viewer = napari.Viewer()
viewer.add_points(new_shell_points, face_color="green", size=0.1)
viewer.add_points(outer_shell.points, size=0.1)
viewer.add_points(all_closest_points, face_color="red", size=0.1)
viewer.add_points(pred_grids, face_color="blue", size=0.1)

##### 2. KNN interpolation
IDEA: Instead of placing grids arbitrarily on the gaps, do the following:
- Compute the distance from each point to its nearest neighbor (edge length).
- Consider the points whose edge length exceeds a threshold (e.g. mean + 2 * std). These points are likely to lie on a gap.
- For each such point, find its k nearest neighbors (KNN) and place a new point at their barycenter.

The result should be an evenly spaced grid!

FLAWED IDEA: if the whole outer-shell point cloud is considered, gaps cannot be found this way, because the nearest neighbor of a given point always lies on the same mesh, never on a neighboring one.

In [None]:
# Distance from every shell point to its nearest neighbour (computed by Open3D),
# followed by the mean and standard deviation of those distances.
pc = o3d.geometry.PointCloud()
pc.points = o3d.utility.Vector3dVector(outer_shell.points)
nearest_distances = np.asarray(pc.compute_nearest_neighbor_distance())
mean_nearest_distances = nearest_distances.mean()
std_nearest_distances = nearest_distances.std()

In [None]:
# Flag points whose nearest-neighbour distance lies more than two standard
# deviations away from the mean, on either side.
# NOTE(review): the markdown above only describes the upper tail (large
# distances -> gaps). The lower tail selects unusually dense points instead —
# confirm that this is intended.
gap_mask = (
    (nearest_distances > mean_nearest_distances + 2 * std_nearest_distances)
    | (nearest_distances < mean_nearest_distances - 2 * std_nearest_distances)
)
gap_points = outer_shell.points[gap_mask]

In [None]:
# Show all shell points and highlight in red the ones selected by gap_mask.
viewer = napari.Viewer()
viewer.add_points(outer_shell.points, size=0.1)
viewer.add_points(gap_points, face_color="red", size=0.1)

In [None]:
# Fill the gaps using the 'spline' strategy of OuterShell.interpolate_gaps.
outer_shell.interpolate_gaps('spline')

##### 3. Linear Interpolation between mesh borders
ALGORITHM:
- Find "borders" of cell meshes (i.e., single layers of points that are the closest to another mesh). 
- For each mesh border:
    - For each point of the border:
        - Find closest point on the neighboring cell border
        - Sample points on the line that joins these two points

In [None]:
# # 1. Get all pairs of neighbors
# neighbor_pairs = set()
# for idx, neighbors in enumerate(outer_shell._neighbors_lst):
#     for neighbor in neighbors:
#         pair = tuple(sorted((idx, neighbor)))
#         neighbor_pairs.add(pair)

# # colors = ["blue", "cyan", "green", "yellow", "red", "purple"]
# # viewer = napari.Viewer()
# # viewer.add_points(outer_shell.points, size=0.1)

# # 2. Iterate over the neighbor pairs
# all_sampled_points = []
# for i, (idx_1, idx_2) in enumerate(neighbor_pairs):
#     # 2.a. Get closest points for each cell in the pair
#     mesh_1, mesh_2 = outer_shell._meshes[idx_1], outer_shell._meshes[idx_2]
#     closest_point_idxs_1 = mesh_1.k_closest_dict[idx_2]
#     closest_point_idxs_2 = mesh_2.k_closest_dict[idx_1]
#     closest_points_1 = mesh_1.points[closest_point_idxs_1]
#     closest_points_2 = mesh_2.points[closest_point_idxs_2]
#     closest_points_kdtree_1 = mesh_1.k_closest_kdtrees[idx_2]
#     closest_points_kdtree_2 = mesh_2.k_closest_kdtrees[idx_1]

#     # 2.b. For each point in closest_points_1, find the border point in closest_points_2 set using KDTree 
#     _, border_idxs_2 = closest_points_kdtree_2.query(closest_points_1, k=1)
#     border_idxs_2 = list(set(border_idxs_2))
#     border_points_2 = closest_points_2[border_idxs_2]

#     # 2.c. Now for each border point of mesh 2, get the associated border point of mesh 1
#     _, border_idxs_1 = closest_points_kdtree_1.query(border_points_2, k=1)

#     # print(border_idxs_1[:5], border_idxs_2[:5])
#     # print(border_points_2[:5])

#     # 2.d. Store the pair of indices in tuples to avoid mixing up the pairs
#     border_idxs_pairs = [
#         (b_idx1, b_idx2) 
#         for b_idx1, b_idx2 in zip(border_idxs_1, border_idxs_2)
#     ]
#     # Store in a set to remove duplicates
#     border_idx_pairs = set(border_idxs_pairs)
#     # Get ordered border points for mesh 1
#     border_points_1 = closest_points_1[[pair[0] for pair in border_idxs_pairs]]

#     # 2.e. Compute the direction vectors for each pair of border points
#     direction_vectors = border_points_2 - border_points_1

#     # Sample points along the direction vectors
#     num_samples = 3
#     sampling_steps = np.linspace(0, 1, num_samples + 2)[1:-1, np.newaxis]
#     sampled_points = border_points_1[:, np.newaxis, :] +  sampling_steps * direction_vectors[:, np.newaxis, :]
#     sampled_points = sampled_points.reshape(-1, 3)
#     all_sampled_points.append(sampled_points)

#     # viewer.add_points(border_points_1, size=0.2, face_color=colors[i*2])
#     # viewer.add_points(border_points_2, size=0.2, face_color=colors[i*2+1])
#     # viewer.add_points(sampled_points, size=0.2, face_color="pink")
#     # viewer.add_points(border_points_1[:10], size=0.3, face_color=colors[i * 2])
#     # viewer.add_points(border_points_2[:10], size=0.3, face_color=colors[i * 2 + 1])

# all_sampled_points = np.concatenate(all_sampled_points)
# # outer_shell.points = np.vstack([outer_shell.points, all_sampled_points])

# # # viewer = napari.Viewer()
# # # viewer.add_points(outer_shell.points, size=0.1)

In [None]:
# NOTE(review): the positional arguments '1' and 6 are opaque here (compare the
# 'spline' call above). Their meaning is defined by OuterShell.interpolate_gaps;
# consider passing them by keyword.
outer_shell.interpolate_gaps('1', 6)

#### Test creation of mesh from point cloud

In [None]:
# Estimate normals on the shell point cloud so it can be meshed below.
pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(outer_shell.points))
# Exact (non-fast) normal estimation, then consistent orientation via
# orient_normals_consistent_tangent_plane with k=10.
pcd.estimate_normals(fast_normal_computation=False)
pcd.orient_normals_consistent_tangent_plane(k=10)

# Debug: visualise each normal as an offset point.
# point_normals = outer_shell.points + np.asarray(pcd.normals)

# viewer = napari.Viewer()
# viewer.add_points(outer_shell.points, size=0.2)
# viewer.add_points(point_normals, size=0.2, face_color="red")

In [None]:
# Alternative (disabled): ball pivoting with radii derived from the mean
# nearest-neighbour distance.
# distances = pcd.compute_nearest_neighbor_distance()
# avg_dist = np.mean(distances)
# radius = 2 * avg_dist

# mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(
#            pcd,
#            o3d.utility.DoubleVector([radius, radius * 2])
# )

# Poisson surface reconstruction from the oriented point cloud. [0] keeps the
# mesh; the second returned element is not used here.
mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
    pcd,
    depth=8,
    width=0,
    scale=1.1,
    linear_fit=False,
)[0]

# Convert to trimesh, carrying over Open3D's vertex normals.
tri_mesh = trimesh.Trimesh(
    vertices=np.asarray(mesh.vertices),
    faces=np.asarray(mesh.triangles),
    vertex_normals=np.asarray(mesh.vertex_normals),
)

viewer = napari.Viewer()
# viewer.add_points(outer_shell.points, size=0.2, face_color="red")
viewer.add_surface((tri_mesh.vertices, tri_mesh.faces))

In [None]:
# NOTE(review): this repeats the Poisson reconstruction of the previous cell with
# identical parameters (depth=8, width=0, scale=1.1, linear_fit=False); [0] keeps
# only the first returned element. poisson_mesh is not used later in this notebook.
poisson_mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=8, width=0, scale=1.1, linear_fit=False)[0]

#### Test wrapper method 'generate_outer_shell'

In [None]:
# Run the wrapper pipeline with its default arguments (the earlier call passed
# displace_points=True explicitly).
outer_shell.generate_outer_shell()

In [None]:
# Render the generated outer-shell mesh as a surface.
viewer = napari.Viewer()
viewer.add_surface(
    (outer_shell.mesh.vertices, outer_shell.mesh.faces)
)

In [None]:
# Save the shell mesh to STL via OuterShell.mesh_to_file.
# NOTE(review): this is a Linux home-directory path, unlike the N:\ / NAS paths used elsewhere in this notebook.
outer_shell.mesh_to_file("/home/fcarrara/Documents/outer_shell_3_cells.stl")

#### Test 'MeshPrep.create_shell_from_meshes'

In [None]:
from MeshPrep import create_shell_from_meshes
from tqdm import tqdm

In [None]:
cell_ids = [128, 130, 138, 139, 147, 150, 163, 167, 169, 171, 180, 185, 187, 193, 210, 228]
# root = "/nas/groups/iber/Users/Federico_Carrara"
root = r"N:/Users/Federico_Carrara/"
rel_path_to_df = "./Statistics_Collection/outputs/outputs_v5/output_intestine_sample2_b_curated_segmentation_relabel_seq_s_10_e_6_d_8/cell_stats/stats_dataset_intestine_villus.csv"
rel_path_to_meshes = "./Meshes_for_Simulation/examples/cell_clump_intestine/cell_clumps/clean_clump_16_cells/"
path_to_df = os.path.join(root, rel_path_to_df)
path_to_meshes = os.path.join(root, rel_path_to_meshes)

# Load one cleaned mesh per cell id, in the order of cell_ids.
# The loop variable is `cell_id` (rather than `id`) so the builtin id() is not
# shadowed in the notebook's global namespace. The comprehension also keeps it
# from leaking out of the loop at all.
mesh_lst = [
    trimesh.load(os.path.join(path_to_meshes, "clean_meshes", f"cell_{cell_id}.stl"))
    for cell_id in tqdm(cell_ids)
]

# Build the shell from the loaded cells. The CSV at path_to_df is passed as
# cell_stats_data_path; w_displacement=False matches the "nodispl" export below.
shell_mesh = create_shell_from_meshes(
    meshes=mesh_lst,
    cell_idxs=cell_ids,
    cell_stats_data_path=path_to_df,
    w_displacement=False
)

In [None]:
# Render the 16-cell shell as a surface (the point overlays are left disabled).
viewer = napari.Viewer()
# viewer.add_points(shell_mesh.points, size=0.15, face_color="red")
# viewer.add_points(shell_mesh.point_normals, size=0.15, face_color="pink")
viewer.add_surface(
    (shell_mesh.vertices, shell_mesh.faces)
)

In [None]:
# Export the 16-cell shell to STL under `root`; "nodispl" matches w_displacement=False above.
shell_mesh.export(os.path.join(root, "new_shell_nodispl.stl"))