# Imports

In [1]:
%cd ..

/hdd/aouadt/these/projets/3d_segm


In [2]:
from importlib import reload
import os
import random
import pathlib
import re
from collections import defaultdict
import shutil

import nibabel as nib
import cv2
from time import time
from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import skimage.morphology as morp
import skimage.measure as meas
from scipy.ndimage import rotate
from scipy.spatial import procrustes
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import trimesh
import pyqtgraph
import networkx as nx
import open3d as o3d
from open3d.web_visualizer import draw

import general.open3d_utils as ou
import general.queue as qu
import general.utils as u
import general.array_morphology as am
import ssm.best_fit_transform as bft
import ssm.utils as su
import ssm.icp as icp
import ssm.hungarian_icp as hicp
import ssm.dijkstra as dij
import ssm.sample_mesh as smesh
import ssm.shape as sh

def reload_modules():
    for modl in [bft, ou, qu, u, am, su, icp, hicp, dij, smesh, sh,]:
        reload(modl)

reload_modules()

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
[Open3D INFO] Resetting default logger to print to terminal.


In [3]:
%matplotlib notebook

In [4]:
def get_normal_lineset(pcd, normals):
    pcdn = ou.numpy_to_o3d_pcd(np.asarray(pcd.points) + normals, normals=normals)
    line_set = o3d.geometry.LineSet()
    line_set.points = (pcd + pcdn).points
    line_set.lines = o3d.utility.Vector2iVector([[i, i + len(normals)] for i in range(len(normals))])
    return line_set

def get_lineset_matches(pcd1, pcd2):
    n_points = len(np.asarray(pcd1.points))
    line_set = o3d.geometry.LineSet()
    line_set.points = (pcd1 + pcd2).points
    line_set.lines = o3d.utility.Vector2iVector(
        [[i, n_points + i] for i in range(n_points)]
    )
    return line_set

In [121]:
MESH_DIR = "/hdd/datasets/CT-ORG/meshes/labels/"
STEP_SIZE = 4


blacklist = [
    '/hdd/datasets/CT-ORG/meshes/labels/labels-20/step_size_4',
    '/hdd/datasets/CT-ORG/meshes/labels/labels-19/step_size_4',
]

all_meshes = su.sort_by_regex(list(set([
    os.path.join(MESH_DIR, filename, f"step_size_{STEP_SIZE}") for filename in os.listdir(MESH_DIR)
]).difference(blacklist)))



# Test Open3d

In [None]:
cube_red = o3d.geometry.TriangleMesh.create_box(1, 2, 4)
cube_red.compute_vertex_normals()
cube_red.paint_uniform_color((1.0, 0.0, 0.0))
draw(cube_red)

In [22]:
%%time
path_segm = os.path.abspath("/hdd/datasets/CT-ORG/raw/labels_and_README/labels-11.nii.gz")
seg1n = nib.load(path_segm)
seg1 = np.round(seg1n.get_fdata()) == 2
reg1 = (u.get_most_important_regions(seg1) > 0).astype(int)
verts1, faces1, normals1, values1 = meas.marching_cubes(reg1, step_size=2)
print(len(verts1), len(faces1))

11328 22652
CPU times: user 11.8 s, sys: 7.74 s, total: 19.5 s
Wall time: 9.87 s


In [None]:
pcd = numpy_to_o3d_pcd(verts1)
draw(pcd)

In [None]:
msh = numpy_to_o3d_mesh(verts1, faces1, normals1)
draw([msh])

# 0) Sampling

### 2D Tests

In [5]:
def check_same_dict(dic1, dic2):
    assert len(dic1) == len(dic2)
    same = True
    for key in dic1:
        same = same and dic1[key] == dic2[key]
    return same

In [6]:
reload_modules()
def create_grid_graph(W, H):
    graph = dij.Graph()

    for i in range(1, W - 1):
        for j in range(1, H - 1):
            if (i, j) in graph.nodes:
                continue
            graph.add_node((i, j))
            for (k1, k2) in [(-1, 0), (1, 0), (0, 1), (0, -1),]:
                graph.add_node((i + k1, j + k2))
                value = 1
                graph.add_edge((i, j), (i + k1, j + k2), value)
                graph.add_edge((i + k1, j + k2), (i, j), value)
            for (k1, k2) in [(1, 1), (1, -1), (-1, 1), (-1, -1)]:
                graph.add_node((i + k1, j + k2))
                value = np.sqrt(2)
                graph.add_edge((i, j), (i + k1, j + k2), value)
                graph.add_edge((i + k1, j + k2), (i, j), value)
                             
    return graph

In [None]:
%%time
reload_modules()

n = 100
graph = create_grid_graph(n, n)

t1 = time()
visited, path, closest = dij.dijkstra(graph, [(n//2, n//2)])
print("Dijkstra Time:", time() - t1)

points = []
ar = np.zeros((n, n))
for (i, j) in visited:
    ar[i, j] = visited[(i, j)]

fig = plt.figure()

ax = fig.add_subplot(121)
ax.imshow(ar)

ax = fig.add_subplot(122)
ax.imshow(np.cos(.5 * ar))

### On Data

In [6]:
%%time
# path_segm = os.path.abspath("/hdd/datasets/CT-ORG/raw/labels_and_README/labels-11.nii.gz")
# seg1n = nib.load(path_segm)
# seg1 = np.round(seg1n.get_fdata()) == 2
# reg1 = (u.get_most_important_regions(seg1) > 0).astype(int)
# verts1, faces1, normals1, values1 = meas.marching_cubes(reg1, step_size=2)
# print(len(verts1), len(faces1))
path_shape = '/hdd/datasets/CT-ORG/meshes/labels/labels-11/step_size_2'
cur_shape = sh.Shape.load_from_path(path_shape)
print(len(cur_shape.vertexes), len(cur_shape.faces))

11328 22652
CPU times: user 4.15 ms, sys: 4.34 ms, total: 8.49 ms
Wall time: 3.71 ms


In [None]:
reload_modules()
cur_msh = cur_shape.o3d_mesh()
draw(cur_msh)

In [None]:
# Parcours en largeur

# Q = Queue()
# fronts = []

# node = 0
# already_visited = set()

# gmesh = sm.create_mesh_graph(shape_to_plots[0].vertexes, shape_to_plots[0].faces)
# closest = shape_to_plots[0].closest_sample_point

# for child in gmesh.edges[node]:
#     already_visited.add(child)
#     Q.add((child, node))

# i = 0
# while not Q.is_empty():
#     node, parent = Q.pop()
#     for child in gmesh.edges[node].difference(already_visited):
#         already_visited.add(child)
#         Q.add((child, node))
#     if closest[node] != closest[parent]:
#         fronts.append(node)

#     i += 1

#     if i > len(gmesh.nodes):
#         print('Too many nodes !! breaking.')
#         break


In [10]:
reload_modules()
gmesh = smesh.create_mesh_graph(cur_shape.vertexes, cur_shape.faces)
# dists, path, closest = dij.dijkstra(gmesh, initial_set=[0])
all_points, ar_dist, ar_clos = smesh.dijkstra_sampling(gmesh=gmesh, n_points=100, )

100%|██████████| 100/100 [00:00<00:00, 138.67it/s]


In [11]:
# Voronoi cells of a subset of the graph

visited_edges = set()

neighbors = defaultdict(set)
gnei = nx.Graph()

i = 0
for node1 in gmesh.nodes:
    i += 1
    for node2 in gmesh.edges[node1]:
        if (node2, node1) in visited_edges:
            continue
        visited_edges.add((node1, node2))
#         visited_edges.add((node2, node1))
#         c1 = closest[node1]
#         c2 = closest[node2]
        c1, c2 = ar_clos[node1], ar_clos[node2]
        if c1 != c2:
            neighbors[c1].add(c2)
            neighbors[c2].add(c1)
            gnei.add_edge(c1, c2)
    
print(len(visited_edges), i)

33978 11328


In [12]:
plt.subplot(111)
nx.draw(gnei, with_labels=False)

<IPython.core.display.Javascript object>

In [28]:
len_cliques = [len(clique) for clique in nx.enumerate_all_cliques(gnei)]
plt.hist(len_cliques)

<IPython.core.display.Javascript object>

(array([100.,   0.,   0.,   0.,   0., 294.,   0.,   0.,   0., 196.]),
 array([1. , 1.2, 1.4, 1.6, 1.8, 2. , 2.2, 2.4, 2.6, 2.8, 3. ]),
 <BarContainer object of 10 artists>)

In [None]:
reload_modules()
faces_sample = np.array(
    [clique for clique in nx.enumerate_all_cliques(gnei) if len(clique) == 3]
)

ou.plot_mesh(cur_shape.vertexes, faces_sample)

In [7]:
all_points, ar_dist, ar_clos = smesh.dijkstra_sampling(verts1, faces1, n_points=100, )

100%|██████████| 100/100 [00:00<00:00, 133.80it/s]


In [8]:
# %%
fig = plt.figure()
ax = fig.add_subplot(121, projection='3d')

ax.scatter(*verts1.T, c=ar_dist)
ax.scatter(*verts1[all_points].T, c='r', s=100, alpha=1)

ax = fig.add_subplot(122, projection='3d')
ax.scatter(*verts1.T, c=ar_clos)



<IPython.core.display.Javascript object>

<mpl_toolkits.mplot3d.art3d.Path3DCollection at 0x7f1dca6cae80>

In [None]:
cur_shape.perform_sampling(n_points=1000, verbose=True)
cur_shape.compute_sample_faces()

mesh_sample = ou.numpy_to_o3d_mesh(vertices=cur_shape.sample, triangles=cur_shape.faces_sample)
mesh_sample.compute_vertex_normals()

cur_mesh = cur_shape.o3d_mesh()
cur_sample = ou.get_o3d_pcd_colored(cur_shape.sample)

draw([
    mesh_sample,
    cur_mesh,
    cur_sample
])

# I) Training: Shape creation

## Test ICP registration

In [35]:
# %%

path_segm = os.path.abspath("/hdd/datasets/CT-ORG/raw/labels_and_README/labels-11.nii.gz")
seg1n = nib.load(path_segm)
seg1 = np.round(seg1n.get_fdata()) == 2

path_segm = os.path.abspath("/hdd/datasets/CT-ORG/raw/labels_and_README/labels-10.nii.gz")
seg2n = nib.load(path_segm)
seg2 = np.round(seg2n.get_fdata()) == 2


In [36]:
%%time
reload_modules()
reg1 = (u.get_most_important_regions(seg1) > 0).astype(int)
reg2 = (u.get_most_important_regions(seg2) > 0).astype(int)



CPU times: user 16.9 s, sys: 674 ms, total: 17.6 s
Wall time: 17.6 s


In [37]:
%%time
verts1, faces1, normals1, values1 = meas.marching_cubes(reg1, step_size=4)
verts2, faces2, normals2, values2 = meas.marching_cubes(reg2, step_size=4)


CPU times: user 405 ms, sys: 108 ms, total: 513 ms
Wall time: 512 ms


In [None]:
# %%
fig = plt.figure()
ax = fig.add_subplot(121, projection='3d')
ax.plot_trisurf(*verts1.T, triangles=faces1)
ax.plot_trisurf(*verts2.T, triangles=faces2)

ax = fig.add_subplot(122, projection='3d')
ax.scatter(*verts1.T,)
ax.scatter(*verts2.T,)

In [39]:
mean1, std1 = verts1.mean(0), verts1.std(0)
mean2, std2 = verts2.mean(0), verts2.std(0)

nverts1 = (verts1 - mean1)/std1
nverts2 = (verts2 - mean2)/std2

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(*nverts1.T,)
ax.scatter(*nverts2.T,)

In [41]:
%%time
reload_modules()
T, errs, n_iters = icp.icp(nverts1, nverts2, allow_reflection=False, max_iterations=1000, tolerance=1e-5)
print(T, np.max(errs), n_iters)

[[ 0.99950916  0.02791258  0.01422405  0.01002564]
 [-0.02296148  0.96159541 -0.27350874 -0.13305637]
 [-0.02131212  0.27304788  0.96176434 -0.0616374 ]
 [ 0.          0.          0.          1.        ]] 0.925477981364535 60
CPU times: user 348 ms, sys: 7.21 ms, total: 355 ms
Wall time: 353 ms


In [None]:
Tverts1 = su.transform_cloud(T, nverts1)

fig = plt.figure(figsize=(15, 7))
ax1 = fig.add_subplot(121, projection='3d')
ax1.scatter(*Tverts1.T,)
ax1.scatter(*nverts2.T,)

ax2 = fig.add_subplot(122, projection='3d')
ax2.plot_trisurf(*Tverts1.T, triangles=faces1)
ax2.plot_trisurf(*nverts2.T, triangles=faces2)

In [111]:
Tn1 = su.get_norm_transform(mean1, std1, invert=False)
Tn2 = su.get_norm_transform(mean2, std2, invert=True)
T2 = Tn2 @ T @ Tn1

In [None]:


# T2verts1 = transform_cloud(T2, verts1 - mean1)
T2verts1 = su.transform_cloud(T2, verts1)

fig = plt.figure(figsize=(14, 7))
ax1 = fig.add_subplot(121, projection='3d')
ax1.scatter(*T2verts1.T, label="T1")
ax1.scatter(*verts2.T, label="2")
ax1.legend()

ax1 = fig.add_subplot(122, projection='3d')
ax1.scatter(*Tverts1.T, label="T1")
ax1.scatter(*nverts2.T, label="2")
ax1.legend()

## Test Load Meshes

In [None]:
mesh_dir = "/hdd/datasets/CT-ORG/meshes/labels/"
dest = os.path.join(mesh_dir, random.choice(os.listdir(mesh_dir)), 'step_size_1')
print(dest)

verts = np.load(os.path.join(dest, "vertexes.npy"))
faces = np.load(os.path.join(dest, "faces.npy"))

fig = plt.figure(figsize=(7, 7))
ax1 = fig.add_subplot(111, projection='3d')
ax1.plot_trisurf(*verts.T, triangles=faces)

# ax1 = fig.add_subplot(122, projection='3d')
# ax1.scatter(*verts.T)

## Uniform Sampling

In [9]:
mesh_dir = "/hdd/datasets/CT-ORG/meshes/labels/"
step_size = 4

all_meshes = su.sort_by_regex([
    os.path.join(mesh_dir, filename, f"step_size_{step_size}") for filename in os.listdir(mesh_dir)
])
ref_mesh = all_meshes[3]
print(ref_mesh)
ref_verts = np.load(os.path.join(ref_mesh, "vertexes.npy"))
ref_faces = np.load(os.path.join(ref_mesh, "faces.npy"))

sampling = smesh.dijkstra_sampling(ref_verts, ref_faces, 30)


Tnref = su.get_norm_transform(ref_verts.mean(0), ref_verts.std(0))
nref_verts = su.transform_cloud(Tnref, ref_verts)




  0%|          | 0/30 [00:00<?, ?it/s]

/hdd/datasets/CT-ORG/meshes/labels/labels-3/step_size_4


100%|██████████| 30/30 [00:00<00:00, 269.49it/s]


In [None]:
fig = plt.figure()
ax1 = fig.add_subplot(121, projection='3d')
ax1.plot_trisurf(*ref_verts.T, triangles=ref_faces)

ax1 = fig.add_subplot(122, projection='3d')
ax1.scatter(*ref_verts[sampling[0]].T, c='r', s=200)
ax1.scatter(*ref_verts.T, c=sampling[1])

In [None]:
ref_msh = ou.numpy_to_o3d_mesh(
    vertices=ref_verts,
    triangles=ref_faces,
    vertex_colors=u.colormap_1d(u.max_min_norm(sampling[1])),
)
ref_msh.compute_vertex_normals()

ref_pcd = ou.numpy_to_o3d_pcd(ref_verts[sampling[0]])
ref_pcd.paint_uniform_color([0, 1, 0])

draw([ref_pcd, ref_msh])

## Performing Registration Using Shape Methods

### Registration

In [17]:
# Analyze normals and normals transforms

reload_modules()
N_SAMPLES = 1000
REF_IDX = 2
CUR_IDX = 5

# ref_mesh = all_meshes[REF_IDX]
ref_mesh = '/hdd/datasets/CT-ORG/meshes/labels/labels-2/step_size_2'
print(ref_mesh)

ref_shape = sh.Shape.load_from_path(ref_mesh, label='reference')
ref_shape.perform_sampling(N_SAMPLES, verbose=True)
ref_shape.Tref = np.eye(4)

# cur_mesh = all_meshes[CUR_IDX]
cur_mesh = '/hdd/datasets/CT-ORG/meshes/labels/labels-26/step_size_2'
print(cur_mesh)

cur_shape = sh.Shape.load_from_path(cur_mesh, label='current')
cur_shape.perform_sampling(N_SAMPLES, verbose=True);
cur_shape.set_reference(ref_shape)


/hdd/datasets/CT-ORG/meshes/labels/labels-2/step_size_2


100%|██████████| 1000/1000 [00:03<00:00, 264.29it/s]


/hdd/datasets/CT-ORG/meshes/labels/labels-26/step_size_2


100%|██████████| 1000/1000 [00:01<00:00, 922.84it/s]


<ssm.shape.Shape at 0x7f574a359250>

In [18]:
%%time
t1 = time()
Tref, sample_idx, errs, n_iters = cur_shape.register_icp_to_reference()
print('Registration time', time() - t1)
cur_shape.match_samples(matching_method="perfect");
print(cur_shape.Tref)
print(n_iters)

Registration time 40.29895997047424
[[-7.34473634e-35 -2.85110585e-34  1.19465351e-34  2.52171227e+02]
 [ 0.00000000e+00 -1.22791064e-34 -2.93047582e-34  2.44570905e+02]
 [ 3.09127831e-34 -6.77409753e-35  2.83844228e-35  7.12342187e+01]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  1.00000000e+00]]
999
CPU times: user 3min 19s, sys: 12min 51s, total: 16min 11s
Wall time: 45 s


In [49]:
cur_msh = cur_shape.o3d_mesh()
cur_pcd = cur_shape.o3d_pcd()
ref_msh = cur_shape.o3d_ref_mesh_transformed()
ref_pcd = cur_shape.o3d_ref_pcd_transformed(point_color='r')
draw([cur_pcd, ref_pcd, cur_msh, ref_msh])

WebVisualizer(window_uid='window_9')

In [None]:
pcd_ref = ou.numpy_to_o3d_pcd(cur_shape.reference.sample)
Tpcd_ref = ou.numpy_to_o3d_pcd(su.transform_cloud(cur_shape.Tref, cur_shape.reference.sample))
pcd_cur = ou.numpy_to_o3d_pcd(cur_shape.sample)
lineset = get_lineset_matches(Tpcd_ref, pcd_cur)

Tlineset = get_lineset_matches(pcd_ref, Tpcd_ref)

pcd_ref.paint_uniform_color([0, 0, 1])
Tpcd_ref.paint_uniform_color([1, 0, 0])
pcd_cur.paint_uniform_color([0, 1, 0])

draw([
#     pcd_ref,
    Tpcd_ref,
    pcd_cur,
    lineset,
#     Tlineset
])

In [85]:
df = pd.DataFrame(np.asarray(pcd_cur.points))
print(len(df.drop_duplicates()))
print(len(cur_shape.sample_idx))

557
1000


### Check normals

In [None]:
# Checking if normals are OK

pcd_ref = ou.numpy_to_o3d_pcd(ref_shape.vertexes, normals=ref_shape.normals)
# Marching cubes normals
lineset = get_normal_lineset(pcd_ref, 2*ref_shape.normals)

msh_ref = ou.numpy_to_o3d_mesh(
    vertices=ref_shape.vertexes,
    triangles=ref_shape.faces,
    vertex_colors=u.colormap_1d(u.max_min_norm(ref_shape.dist_to_sample)),
)
msh_ref.compute_vertex_normals()
# Open3d mesh normals
lineset2 = get_normal_lineset(pcd_ref, -2*np.asarray(msh_ref.vertex_normals))

msh_Tref = ou.numpy_to_o3d_mesh(
    vertices=su.transform_cloud(cur_shape.Tref, ref_shape.vertexes),
    triangles=ref_shape.faces,
    vertex_colors=u.colormap_1d(u.max_min_norm(ref_shape.dist_to_sample)),
)
msh_Tref.compute_vertex_normals()
pcd_cur = ou.numpy_to_o3d_pcd(cur_shape.sample, normals=cur_shape.normals[cur_shape.sample_idx])
pcd_Tref = ou.numpy_to_o3d_pcd(
    su.transform_cloud(cur_shape.Tref, ref_shape.sample), 
)

# Rotated Marching cubes normals
lineset_tref = get_normal_lineset(
    pcd_Tref, normals=ref_shape.normals[ref_shape.sample_idx] @ cur_shape.Rotref.T)
# Rotated [open3d mesh] normals
lineset_tref2 = get_normal_lineset(
    pcd_Tref, normals=-np.asarray(msh_ref.vertex_normals)[ref_shape.sample_idx]@cur_shape.Rotref.T)
# Open3d [rotated mesh] normals
lineset_tref3 = get_normal_lineset(
    pcd_Tref, normals=-np.asarray(msh_Tref.vertex_normals)[ref_shape.sample_idx])


pcd_Tref.paint_uniform_color([1, 0, 0])
pcd_cur.paint_uniform_color([0, 1, 0])

draw([
#     pcd_ref,
#     lineset,
#     lineset2,
#     msh_ref
#     pcd_cur,
#     pcd_Tref,
    lineset_tref,
    lineset_tref2,
    lineset_tref3,
    msh_Tref
])

In [88]:
# Analytical comparison of normals

verts = cur_shape.reference.vertexes
faces = cur_shape.reference.faces
Rot = cur_shape.Rotref

ref_msh = ou.numpy_to_o3d_mesh(
    vertices=verts,
    triangles=faces
)
ref_msh.compute_vertex_normals()

Tref_msh2 = ou.numpy_to_o3d_mesh(
    vertices=verts@Rot.T,
    triangles=cur_shape.reference.faces
)
Tref_msh2.compute_vertex_normals()

norms = np.asarray(ref_msh.vertex_normals)
Rotnorms = np.asarray(Tref_msh2.vertex_normals)

Rotnorms2 = norms @ Rot.T

np.abs(Rotnorms - Rotnorms2).sum()

2.073898442301464e-11

## Creating shape model

### Computing shapes

In [5]:
label_df = pd.read_csv("/hdd/datasets/CT-ORG/meshes/labels_description.csv")

In [6]:
# Using shape class

reload_modules()
n_samples = 1000
STEP_SIZE = 2


candidates = label_df[
    (label_df['n_vertexes'] > n_samples) & (label_df['step_size'] == STEP_SIZE)]

ref_mesh = candidates['full_path'].iloc[2]
print(ref_mesh)
ref_shape = sh.Shape.load_from_path(ref_mesh, Tref=np.eye(4))
ref_shape.perform_sampling(n_samples, verbose=True);

/hdd/datasets/CT-ORG/meshes/labels/labels-2/step_size_2


100%|██████████| 1000/1000 [00:04<00:00, 240.86it/s]


In [7]:
# Using classic ICP (nearest neighbor) with Perfect Matching Sampling

n_shapes = 10
all_shapes = []
all_errs = []
all_n_iters = []
not_converged = []

# iterator = su.sort_by_regex(list(set(all_meshes).difference([ref_mesh])))[10:10+n_shapes]
iterator = list(candidates['full_path'].iloc[10:10+n_shapes])
for idx in tqdm(range(len(iterator))):
    mesh_path = iterator[idx]
    cur_shape = sh.Shape.load_from_path(mesh_path, reference=ref_shape)
    
    _, _, errs, n_iters = cur_shape.register_icp_to_reference(max_iterations=100)
    
    if n_iters == 99:
        not_converged.append([idx, cur_shape])
    elif np.linalg.norm(cur_shape.Tref[:3, :3]) < 1e-5:
        not_converged.append([idx, cur_shape])
    
    else:
        cur_shape.perform_sampling(len(cur_shape.reference.sample_idx))
        cur_shape.match_samples(matching_method="perfect")

        all_shapes.append(cur_shape)
        all_errs.append(errs)
        all_n_iters.append(n_iters)
        
print("Not converged:", len(not_converged))

100%|██████████| 10/10 [01:17<00:00,  7.75s/it]

Not converged: 0





In [8]:
cur_shape = random.choice(all_shapes)
print(cur_shape.label, cur_shape.origin_path)
print(cur_shape.Tref)
print(len(cur_shape.vertexes))

fig = plt.figure(figsize=(7, 7))
ax1 = fig.add_subplot(121, projection='3d')
cur_shape.plot_compare_point_cloud(ax1, show_sampling=False)

ax1 = fig.add_subplot(122, projection='3d')
cur_shape.plot_compare_samples(ax1)

None /hdd/datasets/CT-ORG/meshes/labels/labels-18/step_size_2
[[ 5.82941621e-01 -1.30056673e-03 -4.59398049e-03  1.09861036e+02]
 [ 1.28397696e-03  5.82955944e-01 -2.10917144e-03  1.44878389e+02]
 [ 4.59864477e-03  2.09898241e-03  5.82939256e-01  1.90942949e+01]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  1.00000000e+00]]
11583


<IPython.core.display.Javascript object>

<Axes3DSubplot:>

In [10]:
# An O3D look at the registration

reload_modules()
cur_shape.dijkstra_to_sample()

cur_mesh = cur_shape.o3d_mesh()
cur_pcd = cur_shape.o3d_pcd("b")
ref_mesh = cur_shape.o3d_ref_mesh_transformed()

draw([
    cur_mesh,
    cur_pcd,
    ref_mesh,
])

WebVisualizer(window_uid='window_1')

### Check normals

In [None]:
# Check normals rotation conservation

pcd_ref = cur_shape.reference.o3d_pcd()
line_set1 = get_normal_lineset(
    pcd_ref, 2*cur_shape.reference.normals[cur_shape.reference.sample_idx])

ref_msh = cur_shape.reference.o3d_mesh()
Tref_msh = cur_shape.o3d_ref_mesh_transformed()

# Only rotation
Tref_msh2 = ou.numpy_to_o3d_mesh(
    vertices=cur_shape.reference.vertexes@cur_shape.Rotref.T,
    triangles=cur_shape.reference.faces
)
Tref_msh2.compute_vertex_normals()

# Only rotation
Rotpcd_ref = ou.numpy_to_o3d_pcd(
    cur_shape.reference.sample @ cur_shape.Rotref.T,
    normals=cur_shape.reference.normals[cur_shape.reference.sample_idx] @ cur_shape.Rotref.T
)
line_set2 = get_normal_lineset(Rotpcd_ref, 2*np.asarray(Rotpcd_ref.normals))


draw([
    line_set1,
    line_set2,
    ref_msh,
#     Tref_msh,
    Tref_msh2
])




In [None]:
# Comparison of multiple normals

# Normals without applying rotation (should not match)
pcd_ref = ou.numpy_to_o3d_pcd(
    su.transform_cloud(cur_shape.Tref, cur_shape.reference.sample),
    normals=cur_shape.reference.normals[cur_shape.reference.sample_idx]
)
line_set1 = get_normal_lineset(pcd_ref, 2*np.asarray(pcd_ref.normals))

# Normals after rotation
pcd_ref2 = ou.numpy_to_o3d_pcd(
    su.transform_cloud(cur_shape.Tref, cur_shape.reference.sample),
    normals=cur_shape.reference.normals[
        cur_shape.reference.sample_idx] @ (cur_shape.Rotref.T)
)
line_set2 = get_normal_lineset(pcd_ref2, 2*np.asarray(pcd_ref2.normals))

ref_msh = ou.numpy_to_o3d_mesh(
    vertices=su.transform_cloud(cur_shape.Tref, cur_shape.reference.vertexes),
    triangles=cur_shape.reference.faces
)
ref_msh.compute_vertex_normals()

# Recomputation of normals
pcd_ref3 = ou.numpy_to_o3d_pcd(
    su.transform_cloud(cur_shape.Tref, cur_shape.reference.sample),
    normals=-np.asarray(ref_msh.vertex_normals)[cur_shape.reference.sample_idx]
)
line_set3 = get_normal_lineset(pcd_ref3, 2*np.asarray(pcd_ref3.normals))

# We should expect line_set2 and line_set3 to be good, and line_set1 to be bad
draw([
    pcd_ref,
    line_set1,
    line_set2,
    line_set3,
    ref_msh,
])

In [168]:
pcd_sample = ou.numpy_to_o3d_pcd(
    cur_shape.sample,
    normals=cur_shape.normals[cur_shape.sample_idx]
)
# pcd_sample.estimate_normals()


In [None]:
%%time

# pcd_sample = cur_shape.o3d_pcd(normals=cur_shape.normals[cur_shape.sample_idx])

msh_sample = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
    pcd_sample, 
#     depth=15, 
#     width=0, 
#     scale=1.1, 
    linear_fit=False)[0]
msh_sample.compute_vertex_normals()
draw([pcd_sample, msh_sample, cur_mesh])

In [None]:
%%time

reload_modules()

bpa_mesh = ou.ball_pivoting_mesh(pcd_sample)
draw([pcd_sample, bpa_mesh, cur_mesh])

### Compute statistics using registred shapes before

In [46]:
reload_modules()

shmodel = sh.SSM(
    shapes=[ref_shape] + all_shapes,
#         all_shapes[idx] for idx in sorted(set(range(len(all_shapes))).difference([2, 9]))], 
    reference=ref_shape
)
shmodel.compute_pca(remove_outliers=True)
shmodel.pca.explained_variance_ratio_.cumsum()

array([0.4346346 , 0.69887584, 0.83488226, 0.89439174, 0.93723006,
       0.96208014, 0.98487608, 1.        , 1.        ])

In [None]:
# Visualize PCA

idx = 0

pca0 = shmodel.get_component(idx)
normal = shmodel.pca_normals[idx]

pcd = ou.get_o3d_pcd_colored(points=pca0, normals=normal)
lineset = get_normal_lineset(pcd, 2 * normal)

msh_sample = random.choice(shmodel.shapes).create_mesh_from_sample_faces(pca0)
msh_sample_smooth = msh_sample.filter_smooth_simple()

draw([
    pcd,
    msh_sample,
    msh_sample_smooth,
    lineset
])

In [None]:
# Visualize random generated mesh

(mesh, pcd), feats = shmodel.random_mesh_pca(scalar=2, sampling_mode="uniform")
mesh2 = mesh.filter_smooth_simple()

draw([
#     mesh,
    mesh2,
    pcd
])

In [45]:
# Multiple shapes correspondance

idx1, idx2 = random.sample(range(len(shmodel.all_samples)), 2)
# idx1 = 0
print(idx1, idx2)
samp1, samp2 = shmodel.all_samples[[idx1, idx2]]
norms1, norms2 = shmodel.shapes[idx1].normals, shmodel.shapes[idx1].normals

pcd1 = ou.get_o3d_pcd_colored(samp1, 'g', normals=norms1)
pcd2 = ou.get_o3d_pcd_colored(samp2, 'r', normals=norms2)


line_set = get_lineset_matches(pcd1, pcd2)

draw([pcd1, pcd2, line_set])

8 1


WebVisualizer(window_uid='window_19')

In [None]:
# Check normals
i = 1
shp_idx = i
print(i)
shp0 = shmodel.shapes[shp_idx]
cur_msh = shp0.o3d_mesh()
ref_msh = shp0.o3d_ref_mesh_transformed()
cur_pcd = shp0.o3d_pcd()
ref_pcd = shp0.o3d_ref_pcd_transformed('r')
# msh = ou.numpy_to_o3d_mesh(
#     vertices=shp0.vertexes, triangles=shp0.faces
# )
# msh.compute_vertex_normals()
pcd = ou.numpy_to_o3d_pcd(shp0.sample)
lineset = get_normal_lineset(pcd, shmodel.all_sample_normals[shp_idx])

draw([
#     pcd,
#     msh,
    cur_pcd,
    ref_pcd,
    lineset
])
i += 1

In [67]:
# Link between SVD and PCA

all_sample_vector = shmodel.all_samples.reshape(len(shmodel), -1)
mean_sample = all_sample_vector.mean(0)
X = all_sample_vector - mean_sample
U, S, Vt = np.linalg.svd(X)
print("Principal components", 
      np.abs((U.T @ X)[:len(shmodel)] / S[:, np.newaxis] - Vt[:len(shmodel)]).sum(1) < 1e-5)
print("Principal components", 
      (np.abs((U.T @ X)[:len(shmodel)] / S[:, np.newaxis] - shmodel.pca.components_).sum(1) < 1e-5) |
     (np.abs((U.T @ X)[:len(shmodel)] / S[:, np.newaxis] + shmodel.pca.components_).sum(1) < 1e-5))
print("Singular values", 
      np.abs(S**2/(len(shmodel) - 1) - shmodel.pca.explained_variance_).sum() < 1e-5)

Principal components [ True  True  True  True  True  True  True  True  True  True False]
Principal components [ True  True  True  True  True  True  True  True  True  True False]
Singular values True


In [160]:
# Compare the sample submesh for each shape  
nb_same_faces = np.zeros((len(shmodel), len(shmodel)))

for idx in tqdm(range(len(shmodel))):
    shmodel.shapes[idx].compute_sample_faces()

for idx1 in tqdm(range(len(shmodel))):
    for idx2 in range(idx1 + 1, len(shmodel)):
        in_the_other = []
        for face in shmodel.shapes[idx1].faces_sample:
            if (np.abs(face - shmodel.shapes[idx2].faces_sample).sum(1) == 0).sum() > 0:
                in_the_other.append(face)
        nb_same_faces[idx1, idx2] = len(in_the_other)

print(nb_same_faces)

100%|██████████| 9/9 [00:08<00:00,  1.08it/s]
100%|██████████| 9/9 [00:02<00:00,  3.36it/s]

[[  0. 276. 219. 267. 315.  99. 190. 334. 230.]
 [  0.   0. 101. 129. 168.  65.  78. 119. 100.]
 [  0.   0.   0.  92. 126.  38.  77. 106.  93.]
 [  0.   0.   0.   0. 164.  43. 102. 132. 103.]
 [  0.   0.   0.   0.   0.  67. 108. 191. 122.]
 [  0.   0.   0.   0.   0.   0.  53.  55.  48.]
 [  0.   0.   0.   0.   0.   0.   0. 126.  71.]
 [  0.   0.   0.   0.   0.   0.   0.   0. 113.]
 [  0.   0.   0.   0.   0.   0.   0.   0.   0.]]





## 2D test ICP

In [19]:
def get_ellipse_fn(center, matrix, radius):
    def fn(*x):
        W, L, H = x[0].shape
        Z = np.zeros_like(x[0])
        for i in range(W):
            for j in range(L):
                for k in range(H):
                    coord = np.array([x[0][i, j, k], x[1][i, j, k], x[2][i, j, k]]) - center
                    coord = coord[:, np.newaxis]
                    Z[i, j, k] = np.sqrt( coord.T @ matrix @ coord) - radius
        return Z
    return fn

def get_ellipsoid_points(center, matrix, radius, shape=(40, 40, 40), eps=1e-1):

    XX, YY, ZZ = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), np.arange(shape[2]))
    fn_ellipse = get_ellipse_fn(center, matrix, radius)
    lvset = fn_ellipse(XX, YY, ZZ)
    mask = np.zeros_like(lvset)
    mask[(-eps < lvset) & (lvset < eps)] = 1
    Xs, Ys, Zs = np.where(mask)
    return np.concatenate((Xs[:, np.newaxis], Ys[:, np.newaxis], Zs[:, np.newaxis]), axis=1)


In [161]:
%%time
shape = (70, 70, 70)
eps=1e-3

A1= np.array([
    [6, 2, 1],
    [2, 2, 0],
    [1, 0, 1]
])
center1 = np.array([15, 10, 10])
radius1 = 5

ell1 = get_ellipsoid_points(center1, A1, radius1, shape=shape, eps=eps)

A2 = np.array([
    [6, 0, 0],
    [0, 2, 0],
    [0, 0, 1]
])
center2 = np.array([10, 10, 10])
radius2 = 5

ell2 = get_ellipsoid_points(center2, A2, radius2, shape=shape, eps=eps)

CPU times: user 2.93 s, sys: 10.1 ms, total: 2.94 s
Wall time: 2.94 s


In [162]:
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

ax.scatter(*ell1.T, label='ell1')
ax.scatter(*ell2.T, label='ell2')
ax.legend()

<IPython.core.display.Javascript object>

<matplotlib.legend.Legend at 0x7f78f943c7f0>

In [91]:
%%timeit
reload_modules()
T, errs, iters = icp.icp(ell1, ell2, n_points=20)


11.7 ms ± 108 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [151]:
all_errs = []
all_errs2 = []
all_n_points = range(1, len(ell1))
T_opt = np.array([
    [1, 0, 0, 0],
    [0, 1, 0, -5],
    [0, 0, 1, 0],
    [0, 0, 0, 1]
])

for n_pt in tqdm(all_n_points):
    T, errs ,iters = icp.icp(ell1, ell2, n_points=n_pt, max_iterations=1000)
    all_errs.append(np.abs(errs).max())
    all_errs2.append(np.max(T - T_opt))

100%|██████████| 183/183 [00:12<00:00, 15.09it/s]


In [152]:
fig, axs = plt.subplots(1, 2)
axs[0].plot(all_errs)
axs[1].plot(all_errs2)

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x7f78f8dc4ac0>]

In [121]:
print(T)
print(np.abs(errs).mean())
print(iters)

[[ 1.00000000e+00 -9.90796275e-07 -1.41774147e-06  2.90207983e-05]
 [ 9.89326211e-07  9.99999463e-01 -1.03653054e-03 -4.98964221e+00]
 [ 1.41876770e-06  1.03653054e-03  9.99999463e-01 -2.10099371e-02]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  1.00000000e+00]]
0.011818425868727712
7


In [163]:
reload_modules()
T, errs, iters = icp.icp(ell1, ell2)
print(T)
print(np.abs(errs).max())
print(iters)

[[ 0.89959054 -0.31147948 -0.30613296  8.77730865]
 [ 0.37209864  0.91361427  0.16386447 -9.0799491 ]
 [ 0.22864702 -0.26132259  0.93777985  2.30011588]
 [ 0.          0.          0.          1.        ]]
1.937060309529838
15


In [164]:
Tell2 = ((T @ np.hstack((ell2, np.ones((ell2.shape[0], 1)))).T).T)[:, :-1]
Tell1 = ((T @ np.hstack((ell1, np.ones((ell1.shape[0], 1)))).T).T)[:, :-1]

fig = plt.figure()
ax1 = fig.add_subplot(111, projection='3d')

ax1.scatter(*ell1.T, label='ell1')
ax1.scatter(*Tell1.T, label='Tell1')
ax1.scatter(*ell2.T, label='ell2')
ax1.legend()

<IPython.core.display.Javascript object>

<matplotlib.legend.Legend at 0x7f78f81abaf0>

# II) Inference