In [1]:
import numpy as np
import ot
import trimesh
import open3d as o3d
import utils
from utils import GM
from tqdm import trange

#plotting
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from pathlib import Path

In [2]:
# Target face count handed to `simplify_quadric_decimation` when loading meshes.
sqd = 1000
# Directory containing the "<animal>-poses" folders of .obj models.
models_path = Path.cwd().joinpath('data/models')

In [3]:
#LOAD ANIMALS dataset
# Reads every .obj under "<models_path>/<animal>-poses" and simplifies it to `sqd` faces.
# Builds `meshes_animals` (trimesh meshes) and the parallel label list `names` (e.g. "camel01").
animals = ['camel', 'cat', 'elephant', 'face', 'flamingo', 'head', 'horse', 'lion']
meshes_animals = []
names = []
for animal in animals:
    # rglob yields files in filesystem order, which is not guaranteed. Sort so the mesh
    # order, the "<animal>NN" labels and every index used later are reproducible.
    obj_paths = sorted((models_path / f"{animal}-poses").rglob('*.obj'))
    for counter, path_to_obj in enumerate(obj_paths, start=1):
        o3d_mesh = o3d.io.read_triangle_mesh(str(path_to_obj))
        # open3d reports read failures as a warning and hands back an empty mesh
        # instead of raising, which would only surface much later inside GM/GW.
        if not o3d_mesh.has_triangles():
            raise ValueError(f"could not read a triangle mesh from {path_to_obj}")
        mesh = trimesh.Trimesh(vertices=np.asarray(o3d_mesh.vertices),
                               faces=np.asarray(o3d_mesh.triangles))

        # Simplify the mesh. Pass the target size by keyword: newer trimesh releases take
        # `percent` as the first positional argument of simplify_quadric_decimation.
        mesh = mesh.simplify_quadric_decimation(face_count=sqd)

        meshes_animals.append(mesh)
        names.append(f"{animal}{counter:02d}")

assert meshes_animals, f"no .obj meshes found under {models_path}"

In [4]:
# Work on a small subset: the pairwise GW cell below solves O(N^2) GW problems.
n_meshes = 10
# Re-running this cell must not overwrite the full list with the already-truncated
# subset. Only (re)bind `meshes_animals_all` when `meshes_animals` is not the subset
# this cell produced on a previous run (i.e. on the first run, or after the loading
# cell has been re-run).
if meshes_animals is not globals().get("_meshes_animals_subset"):
    meshes_animals_all = meshes_animals
meshes_animals = _meshes_animals_subset = meshes_animals_all[:n_meshes]

In [None]:
%%time 
data_size = len(meshes_animals)

pairwise_distances = np.zeros((data_size, data_size))

for mesh1_i in range(data_size):
    
    mesh = meshes_animals[mesh1_i]
    X = GM(X=mesh.vertices,Tris=mesh.faces,mode="surface",gauge_mode="djikstra",squared=False)

    pairwise_distances[mesh1_i, mesh1_i] = 0
    
    for mesh2_i in range(mesh1_i+1, data_size):
        
        mesh = meshes_animals[mesh2_i]
        Y = GM(X=mesh.vertices,Tris=mesh.faces,mode="surface",gauge_mode="djikstra",squared=False)
        
        P,log = ot.gromov.gromov_wasserstein(X.g,Y.g,X.xi,Y.xi,log=True)

        pairwise_distances[mesh1_i, mesh2_i] = log["gw_dist"]
        pairwise_distances[mesh2_i, mesh1_i] = log["gw_dist"]

In [None]:
# Persist the distance matrix (np.save appends the .npy suffix) and display it.
np.save(file="pairwise_distances", arr=pairwise_distances)
pairwise_distances

In [None]:
# Figure settings, also reused by the colour-transfer plot further down.
width = 500
height = 500
ambient = 0.4

# Quick look at one simplified mesh.
i = 0
mesh = meshes_animals[i]

# A plain Figure is enough for one 3D panel; converting to go.FigureWidget only added
# an ipywidgets dependency, since .show() renders through plotly.io either way.
fig = go.Figure(
    go.Mesh3d(
        x=mesh.vertices[:, 0],
        y=mesh.vertices[:, 1],
        z=mesh.vertices[:, 2],
        # i, j and k give the vertex indices of each triangle
        i=mesh.faces[:, 0],
        j=mesh.faces[:, 1],
        k=mesh.faces[:, 2],
        showscale=False,
        lighting=dict(ambient=ambient),
    )
)
fig.update_layout(title=names[i], showlegend=False, width=width, height=height,
                  scene_aspectmode="data")
fig.show()

In [None]:
%%time

#choose gm-spaces
i = 0
j = 1

mesh = meshes_animals[i]
X = GM(X=mesh.vertices,Tris=mesh.faces,mode="surface",gauge_mode="djikstra",squared=False)

mesh = meshes_animals[j]
Y = GM(X=mesh.vertices,Tris=mesh.faces,mode="surface",gauge_mode="djikstra",squared=False)


#compute GW Plan
P,log = ot.gromov.gromov_wasserstein(X.g,Y.g,X.xi,Y.xi,log=True)
print("GW Transport costs: {0}".format(log["gw_dist"]))

In [None]:
# Colour-code X by each vertex's distance from a reference point. Push those colours
# to Y through the GW plan P, then plot both shapes side by side.
# NOTE(review): np.min(X.X) is one scalar over all coordinates, so the reference point
# is (m, m, m), not the bounding-box corner np.min(X.X, axis=0). Confirm intent.
cX = np.linalg.norm(X.X - np.min(X.X), axis=1)

# Barycentric projection: rows of P index X vertices (as cX requires) and columns index
# Y vertices, so each Y vertex gets the P-weighted average of the X colours sent to it:
#   cY[k] = sum_l P[l, k] * cX[l] / sum_l P[l, k].
# Normalise by the column sums (the Y marginal), not the row sums.
cY = P.T.dot(cX) / np.sum(P, axis=0)

# Shared colour range, so equal colours mean equal values on both shapes.
# cY is a weighted average of cX, so it stays inside cX's range.
c_lo, c_hi = float(np.min(cX)), float(np.max(cX))


def _gm_mesh_trace(space, intensity, cmin, cmax):
    """Mesh3d trace of a GM space's triangulated surface, coloured by per-vertex `intensity`."""
    return go.Mesh3d(
        x=space.X[:, 0],
        y=space.X[:, 1],
        z=space.X[:, 2],
        # Intensity of each vertex, interpolated over faces and colour-mapped on [cmin, cmax]
        intensity=intensity,
        cmin=cmin,
        cmax=cmax,
        # i, j and k give the vertex indices of each triangle
        i=space.Tris[:, 0],
        j=space.Tris[:, 1],
        k=space.Tris[:, 2],
        showscale=False,
        lighting=dict(ambient=ambient),
    )


fig = make_subplots(rows=1, cols=2,
                    specs=[[{'type': 'scene'}] * 2],
                    subplot_titles=(names[i], names[j]))
fig.add_trace(_gm_mesh_trace(X, cX, c_lo, c_hi), row=1, col=1)
fig.add_trace(_gm_mesh_trace(Y, cY, c_lo, c_hi), row=1, col=2)

fig.update_scenes(aspectmode='data')
fig.update_layout(showlegend=False, width=width, height=height)
fig.show()