In [1]:
import numpy as np
from toolbox.load_data import read_off, read_npy
from toolbox.rescale import rescale
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from mesh_utils import vectorized_is_point_inside_mesh

In [2]:
# Regular N x N x N sampling grid over the unit cube, flattened to (N**3, 3) points.
# NOTE: meshgrid's default "xy" indexing swaps the first two axes, so row
# i*N**2 + j*N + k of `space` is the point (t[j], t[i], t[k]); any per-point
# result reshaped to (N, N, N) therefore has its axes ordered (y, x, z).
N = 100
t = np.linspace(0, 1, N)
space = np.column_stack(
    [coord.ravel() for coord in np.meshgrid(t, t, t, indexing="xy")]
)

In [3]:
path_to_data = "../../data/"

fignames = ["duck", "torus"]


def _load_occupancy(name):
    """Return the voxel volume for shape `name`.

    Tries the precomputed array from `read_npy` first. If that raises
    FileNotFoundError, the mesh is read with `read_off`, its vertices are passed
    through `rescale(..., 0.03, 0.97)` (presumably to fit the mesh inside the
    unit-cube grid), and every point of `space` is tested with
    `vectorized_is_point_inside_mesh`. The result is reshaped to (N, N, N) and
    cast to float32.

    NOTE(review): the computed fallback is never written back, so the slow path
    reruns on every execution until a matching .npy file exists.
    """
    try:
        return read_npy(path_to_data, name)
    except FileNotFoundError:
        verts, faces = read_off(path_to_data, name)
        verts = rescale(verts, 0.03, 0.97)
        inside = vectorized_is_point_inside_mesh(space, verts, faces)
        return inside.reshape(N, N, N).astype(np.float32)


fig_data = [_load_occupancy(name) for name in fignames]

In [None]:
# Preview the input volumes side by side: for each shape, a 3D scatter of the
# grid points whose value exceeds 0.1.
fig = make_subplots(
    rows=1,
    cols=len(fig_data),
    specs=[[{"type": "scatter3d"} for _ in fig_data]],
    subplot_titles=fignames,
)
for i, (figname, fig_arr) in enumerate(zip(fignames, fig_data)):
    # reshape(N**3) fails loudly if a volume does not match the current grid size.
    points_inside = fig_arr.reshape(N**3) > 0.1
    # Fancy-index the (N**3, 3) grid once instead of once per coordinate.
    pts = space[points_inside]
    fig.add_trace(
        go.Scatter3d(x=pts[:, 0], y=pts[:, 1], z=pts[:, 2],
                     mode='markers', name=figname),
        row=1, col=i + 1,
    )
fig.show()

In [15]:
# Barycentric weights for the q interpolation steps: row k is [s_k, 1 - s_k] with
# s_k evenly spaced over [0, 1]. Column 0 weights fig_data[0], column 1 weights fig_data[1].
# A dedicated name is used so the grid axis `t` defined earlier is not clobbered.
q = 5
interp_s = np.linspace(0, 1, q)
W = np.column_stack([interp_s, 1 - interp_s])

In [16]:
from toolbox.blur_kernel import imgaussian
from toolbox.apply_3d_func import apply_3d_func
# Blur width scales with the grid resolution (N / 25 voxels).
mu = N/25


def blur(x):
    """Blur `x` with toolbox `imgaussian` using parameters (mu, mu*50).

    NOTE(review): presumably (sigma, kernel size); confirm against imgaussian.
    """
    return imgaussian(x, mu, mu*50)


def Kv(x):
    """Kernel operator handed to convolutional_barycenter: `blur` applied via `apply_3d_func`."""
    return apply_3d_func(blur, x)

In [17]:
from tqdm import tqdm

def convolutional_barycenter(mus, alphas, area_weights, kernel, kernel_transpose, entropy_limit,
                             n_iter=1500, tol=1e-3, verbose=False):
    """Entropy-regularized Wasserstein barycenter of several distributions.

    Runs the iterative scaling updates of the convolutional barycenter method
    (Solomon et al. 2015, "Convolutional Wasserstein Distances") without the
    entropic-sharpening step. Each iteration sets the barycenter to the
    alpha-weighted geometric mean of ``d_i = v_i * K(a * w_i)``.

    Parameters
    ----------
    mus : array_like, shape (k, n)
        The k input distributions, each flattened to n entries.
    alphas : array_like, shape (k,)
        Non-negative barycentric weights, normalized to sum to 1 internally.
    area_weights : array_like of shape (n,) or None
        Per-entry area/volume weights ``a``. ``None`` means all ones.
    kernel : callable
        Applied to (k, n) arrays. Presumably a blur / heat-kernel operator acting
        on each row (TODO confirm).
    kernel_transpose : callable or None
        Adjoint of ``kernel``. ``None`` reuses ``kernel`` (symmetric kernel).
    entropy_limit : any
        Accepted for interface compatibility but currently ignored, because
        entropic sharpening is not implemented.
    n_iter : int, default 1500
        Maximum number of scaling iterations.
    tol : float, default 1e-3
        Stop once the area-weighted L1 change of the barycenter between two
        iterations falls below this. The check applies from the fourth
        iteration on.
    verbose : bool, default False
        If True, print the iteration index and the change at every iteration.

    Returns
    -------
    numpy.ndarray, shape (n,)
        The (unnormalized) barycenter. If an iteration produces a non-finite
        change, a RuntimeWarning is issued and the last finite iterate is
        returned instead. On the very first iteration that is the all-ones
        initialization.
    """
    import warnings

    mus = np.asarray(mus, dtype=float)
    alphas = np.asarray(alphas, dtype=float)
    alphas = alphas / alphas.sum()
    if area_weights is None:
        area_weights = np.ones(mus.shape[1])
    if kernel_transpose is None:
        kernel_transpose = kernel

    v = np.ones(mus.shape)
    barycenter = np.ones(mus.shape[1])
    for i in range(n_iter):
        old_barycenter = barycenter
        w = mus / kernel_transpose(v * area_weights)
        d = v * kernel(w * area_weights)
        # Floor tiny values so the log below stays finite.
        d[d < 1e-100] = 1e-100
        barycenter = np.exp(np.sum(alphas.reshape(-1, 1) * np.log(d), axis=0))
        v = v * barycenter / d
        change = np.sum(np.abs(old_barycenter - barycenter) * area_weights)
        if verbose:
            print(i, change)
        if not np.isfinite(change):
            # The new iterate contains NaN/inf; do not hand that back to the caller.
            warnings.warn(
                f"convolutional_barycenter diverged at iteration {i}; "
                "returning the last finite iterate",
                RuntimeWarning,
                stacklevel=2,
            )
            return old_barycenter
        if i > 2 and change < tol:
            return barycenter
    return barycenter

In [None]:
# Compute one barycenter per weight row of W, using the convolutional_barycenter
# defined in this notebook (the toolbox import of the same name is not active).

# Flatten each input volume and normalize it to unit mass. This does not depend
# on the weights, so it is built once rather than on every loop iteration.
assert all(dist.sum() > 0 for dist in fig_data), "cannot normalize an empty input volume"
Hv = np.array([dist.ravel() / dist.sum() for dist in fig_data])
entropy_limit = None  # passed through; convolutional_barycenter does not use it

bar = []
for w in tqdm(W):
    w = w / w.sum()
    B = convolutional_barycenter(Hv, w, None, Kv, None, entropy_limit)
    B = B.reshape(N, N, N)
    # Scale so the peak is 1. ndarray.max() avoids iterating N**3 items in Python.
    B = B / B.max()
    bar.append(B)

In [None]:
# Inspect the first barycenter alone: show grid points above 0.1 of its peak.
bary = bar[0]
points_inside = bary.reshape(N**3, -1).ravel() > 0.1
x, y, z = space[points_inside].T
scatter = go.Scatter3d(x=x, y=y, z=z, mode='markers')
fig = go.Figure(data=[scatter])
fig.show()

In [None]:
# Interpolation strip: row 1 is the second input shape, rows 2..q+1 are the
# barycenters bar[0..q-1] (weights W[k]), and row q+2 is the first input shape.
n_rows = q + 2
row_titles = (
    [fignames[1]]
    + [f"{fignames[0]} {wk[0]:.2f} / {fignames[1]} {wk[1]:.2f}" for wk in W]
    + [fignames[0]]
)
fig = make_subplots(
    rows=n_rows,
    cols=1,
    specs=[[{"type": "scatter3d"}] for _ in range(n_rows)],
    subplot_titles=row_titles,
)


def _add_volume_trace(figure, volume, row, name, threshold=0.1):
    """Add to `figure` (column 1, `row`) a 3D scatter of grid points where `volume` > `threshold`."""
    inside = volume.reshape(N**3) > threshold
    pts = space[inside]  # index the grid once, not once per coordinate
    figure.add_trace(
        go.Scatter3d(x=pts[:, 0], y=pts[:, 1], z=pts[:, 2], mode='markers', name=name),
        row=row, col=1,
    )


# `bar` may hold fewer than q volumes (e.g. the barycenter cell was interrupted).
# Plot whatever is available rather than hiding every error behind a bare `except`.
for k, bary in enumerate(bar[:q]):
    _add_volume_trace(fig, bary, row=k + 2, name=row_titles[k + 1])
for row, fig_arr in zip([n_rows, 1], fig_data):
    _add_volume_trace(fig, fig_arr, row=row, name=row_titles[row - 1])
fig.update_layout(height=2000, width=400)
fig.show()