In [17]:
import numpy as np
from src.wasserstein_barycenters_3d import convolutional_barycenter_3d
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pyvista as pv

In [2]:
# Two voxel grids (duck, torus) stacked along a new leading axis:
# initial_shapes[0] is the duck, initial_shapes[1] the torus.
shape1 = np.load('data/numpy/duck.npy')
shape2 = np.load('data/numpy/torus.npy')
initial_shapes = np.stack((shape1, shape2), axis=0)

In [3]:
# Voxel resolution of the input grids. It is read from the data instead of being
# hard-coded, so the coordinate table always lines up with the flattened volumes.
N = initial_shapes.shape[1]
assert initial_shapes.shape[1:] == (N, N, N), f"expected cubic voxel grids, got {initial_shapes.shape[1:]}"
axis_coords = np.linspace(0, 1, N)
# NOTE(review): indexing='xy' (NumPy's default, kept explicit here) maps voxel
# (i, j, k) to (axis_coords[j], axis_coords[i], axis_coords[k]), i.e. the first two
# voxel axes are swapped. That is only right if the .npy grids were built with the
# same convention; otherwise indexing='ij' avoids a mirrored rendering. Confirm.
space3d = np.stack(np.meshgrid(axis_coords, axis_coords, axis_coords, indexing='xy'), axis=-1)
# Flat (N**3, 3) coordinate table: row r holds the position of flattened voxel r.
space = space3d.reshape(N**3, -1)

In [4]:
# Interpolation schedule: row r of `weights` is the pair (t[r], 1 - t[r]),
# i.e. the weight given to the first and to the second input shape.
n_steps = 5
t = np.linspace(0, 1, n_steps)
weights = np.column_stack((t, 1 - t))

In [5]:
# Parameters forwarded to convolutional_barycenter_3d: `sigma` scales with the grid
# resolution (integer N // 25), `tol` is the solver's stopping tolerance
# (exact meaning defined in src.wasserstein_barycenters_3d -- confirm there).
sigma, tol = N // 25, 1e-4

In [6]:
# Point clouds to probability distributions: each shape is rescaled to total mass 1.
# The division builds a new array, so the quotients are not cast back into the
# loaded dtype. With in-place assignment, an int/bool voxel grid would be truncated
# to 0 or left unnormalised. All shapes are normalised in one vectorised step.
masses = initial_shapes.sum(axis=tuple(range(1, initial_shapes.ndim)), keepdims=True)
if np.any(masses == 0):
    raise ValueError("every input shape must contain at least one non-zero voxel")
initial_shapes = initial_shapes / masses

In [7]:
# One barycenter per weight pair in `weights` (same order). Each is rescaled so its
# peak value is 1, which is what the 0.5 thresholds in later cells rely on.
barycenters = []

for alpha in weights:
    barycenter = convolutional_barycenter_3d(initial_shapes, alpha, sigma=sigma, tol=tol)
    # np.nanmax reduces in C; the builtin max() iterated every voxel in Python.
    # It also skips NaNs the solver may produce (see its RuntimeWarnings), so one
    # bad voxel cannot turn the whole volume into NaN.
    barycenter = barycenter / np.nanmax(barycenter)
    barycenters.append(barycenter)

  1%|          | 1/100 [00:02<03:48,  2.31s/it]
 72%|███████▏  | 72/100 [01:26<00:33,  1.20s/it]
 81%|████████  | 81/100 [01:45<00:24,  1.30s/it]
  v[i] = v[i] * barycenter / d[i]
  d[i] = v[i] * kernel(area_weights * w[i])
 17%|█▋        | 17/100 [00:21<01:43,  1.25s/it]
  1%|          | 1/100 [00:02<04:26,  2.69s/it]


In [8]:
def find_neighbours(cloud,i,j,k):
    """
    Sum the six face-adjacent values around voxel (i, j, k) of a 3-D array.

    Diagonal neighbours are ignored and the voxel itself is not counted, so for a
    boolean occupancy grid the result is the number of occupied face neighbours.
    Indices are used as given: i-1 == -1 wraps around via NumPy negative indexing,
    so callers should pass interior coordinates.
    """
    face_offsets = (
        (-1, 0, 0), (1, 0, 0),
        (0, -1, 0), (0, 1, 0),
        (0, 0, -1), (0, 0, 1),
    )
    return sum(cloud[i + di, j + dj, k + dk] for di, dj, dk in face_offsets)

In [9]:
def delete_inside_points(cloud):
    """
    Return a copy of a 3-D voxel array with its interior points removed.

    A voxel counts as interior when its six face neighbours (no diagonals) sum to 6,
    i.e. all of them are set in a boolean occupancy grid. Such voxels are set to
    False in the returned copy. Everything else is kept unchanged, including the
    outermost layer of the array, which has no complete neighbourhood. The input
    is not modified.

    Neighbour sums for all interior voxels are computed at once from shifted
    slices instead of a per-voxel Python loop, and any 3-D shape is supported
    (not only (N, N, N)).

    Parameters
    ----------
    cloud : np.ndarray
        3-D array, typically a boolean occupancy grid such as ``bary > 0.5``.

    Returns
    -------
    np.ndarray
        Array with the same shape and dtype as ``cloud`` holding only surface points.
    """
    cloud = np.asarray(cloud)
    surface = cloud.copy()
    # Adding two boolean arrays in NumPy is a logical OR, so count in integers.
    values = cloud.astype(np.int64) if cloud.dtype == np.bool_ else cloud
    neighbours = (
        values[:-2, 1:-1, 1:-1] + values[2:, 1:-1, 1:-1]
        + values[1:-1, :-2, 1:-1] + values[1:-1, 2:, 1:-1]
        + values[1:-1, 1:-1, :-2] + values[1:-1, 1:-1, 2:]
    )
    # Basic slicing returns a view, so this assignment writes into `surface`.
    interior = surface[1:-1, 1:-1, 1:-1]
    interior[neighbours == 6] = False
    return surface

In [10]:
# Third barycenter (computed from weights[2]) shown as a scatter of its surface
# voxels only: threshold at 0.5, then drop points whose six neighbours are all set.
bary = barycenters[2]
cloud0 = delete_inside_points(bary > 0.5)
points_inside = cloud0.ravel()
x, y, z = space[points_inside].T
fig = go.Figure(data=[go.Scatter3d(x=x, y=y, z=z, mode='markers')])
fig.show()

In [11]:
# First barycenter, thresholded at 0.5 with interior voxels kept.
points_inside = (barycenters[0] > 0.5).ravel()
x, y, z = space[points_inside].T
fig = go.Figure(data=[go.Scatter3d(x=x, y=y, z=z, mode='markers')])
fig.show()

In [16]:
bary = barycenters[0]
# Keep only the surface voxels of the thresholded density; interior voxels do not
# contribute to the reconstructed surface.
cloud0 = delete_inside_points(bary > 0.5)
points_inside = cloud0.ravel()
if not points_inside.any():
    raise ValueError("barycenter 0 has no voxels above the 0.5 threshold; nothing to mesh")
# Index the coordinate table once; result is an (n_points, 3) array.
points = space[points_inside]
point_cloud = pv.PolyData(points)
mesh = point_cloud.reconstruct_surface(progress_bar=True)

# Convert the pyvista mesh to plotly. The flat face array is [n, v0, ..., v_{n-1}, n, ...].
# Reshaping it to (-1, 4) is only valid when every face is a triangle, so
# triangulate first instead of assuming it.
mesh = mesh.triangulate()
arr_faces = mesh.faces.reshape(-1, 4)[:, 1:]
arr_verts = mesh.points

fig = go.Figure(data=[
    go.Mesh3d(
        x=arr_verts[:, 0],
        y=arr_verts[:, 1],
        z=arr_verts[:, 2],
        i=arr_faces[:, 0],
        j=arr_faces[:, 1],
        k=arr_faces[:, 2],
        color='cyan',
    ),
])
fig.show()

Reconstructing surface: 100%|██████████[00:00<00:00]


In [None]:
for i, bary in enumerate(barycenters):
    # Surface voxels of the thresholded density (interior voxels removed).
    cloud0 = delete_inside_points(bary > 0.5)
    points_inside = cloud0.ravel()
    if not points_inside.any():
        # Nothing to reconstruct; skip rather than fail inside the VTK filter.
        print(f"barycenter {i}: no voxels above the 0.5 threshold, skipping")
        continue
    # Index the coordinate table once; result is an (n_points, 3) array.
    points = space[points_inside]
    point_cloud = pv.PolyData(points)
    mesh = point_cloud.reconstruct_surface(progress_bar=True)

    # Convert the pyvista mesh to plotly. The flat face array is [n, v0, ..., n, ...].
    # Reshaping it to (-1, 4) is only valid for all-triangle meshes, so triangulate first.
    mesh = mesh.triangulate()
    arr_faces = mesh.faces.reshape(-1, 4)[:, 1:]
    arr_verts = mesh.points

    fig = go.Figure(data=[
        go.Mesh3d(
            x=arr_verts[:, 0],
            y=arr_verts[:, 1],
            z=arr_verts[:, 2],
            i=arr_faces[:, 0],
            j=arr_faces[:, 1],
            k=arr_faces[:, 2],
            color='cyan',
        ),
    ])
    # barycenters[i] was computed from weights[i]; label the figure so the
    # interpolation sequence can be told apart.
    fig.update_layout(title=f"Barycenter {i}: weights = {np.round(weights[i], 2).tolist()}")
    fig.show()