In [1]:
import time
import functools
import numpy as np
import jax
import jax.numpy as jnp
from jax import grad, jacfwd, vmap, jit

import meshplot as mp

In [3]:
DATA = 'meshes/'
examplemesh = 'sphere2.obj'
#examplemesh = 'testmesh.obj'

def objloader(folder, fname):
    V = []
    F = []
    with open(folder + fname, 'r') as f:
        lines = f.readlines()
        for l in lines:
            token = l.split(' ')[0]
            if token == 'v':
                V.append(jnp.array([float(v) for v in l.split(' ')[1:]]))
            if token == 'f':
                F.append(jnp.array([int(f.split('/')[0]) - 1 for f in l.split(' ')[1:]]))
    #print(F)
    V = jnp.array(V)
    F = jnp.array(F, dtype=jnp.int32)
    
    return V, F

def to_numpy_array(X):
    Np_array = np.zeros(X.shape)
    for n in range(X.shape[0]):
        Np_array[n] = X[n]
    return Np_array


def objloader_np(folder, fname):
    V = []
    F = []
    with open(folder + fname, 'r') as f:
        lines = f.readlines()
        for l in lines:
            token = l.split(' ')[0]
            if token == 'v':
                V.append(jnp.array([float(v) for v in l.split(' ')[1:]]))
            if token == 'f':
                F.append(jnp.array([int(f.split('/')[0]) - 1 for f in l.split(' ')[1:]]))
    #print(F)
    V = np.array(V)
    F = np.array(F, dtype=jnp.int32)
    
    return V, F

def get_edge(V, e0, e1):
    return V[e1] - V[e0]

def normalize(v):
    return v / jnp.linalg.norm(v)

def compute_normal_normalized(e0, e1):
    return normalize(jnp.cross(e0, e1))

def compute_normal(e0, e1):
    return jnp.cross(e0, e1)

compute_normal_vmap = vmap(jnp.cross)

V, F = objloader(DATA, examplemesh)
Nfs = [[fi for fi, f in enumerate(F) if vi in f] for vi in jnp.arange(V.shape[0])]
VF = np.zeros((V.shape[0], F.shape[0]),dtype=np.int32)

for vi in np.arange(V.shape[0]):
    Nf = Nfs[vi]

    for fi in Nf:
        VF[vi, fi] = 1
        
VF = jnp.array(VF)

V_np, F_np = objloader_np(DATA, examplemesh)

%matplotlib notebook
mp.jupyter()
print(V.shape)
print(F.shape)
p = mp.plot(V_np, F_np)
#p.save("test.html")

(162, 3)
(320, 3)


Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0, 0.0,…

## Coding 10.1

### Compute the uniformly weighted vertex normals given by

$
N_U = \sum\limits_i N_i
$

In [4]:
def compute_uniformly_weighted_normals(V, F):
    Nu = []
    for vi, v in enumerate(V):
        # Get neighbouring faces
        Nf = [fi for fi, f in enumerate(F) if vi in f]
        # Compute normals
        Nv = jnp.zeros(3)
        for fi in Nf:
            f = F[fi]
            Nv += compute_normal_normalized(V[f[1]] - V[f[0]], V[f[2]] - V[f[0]])
        Nu.append(normalize(Nv))
    return jnp.array(Nu)
        
N_U = compute_uniformly_weighted_normals(V, F)

## Coding 10.2

### Compute the vertex normal using face area weights

$
N_\nu = \sum\limits_i N_i \cdot \frac{A_i}{\sum\limits_j A_j} = \frac{1}{\sum\limits_j A_j}\sum\limits_i N_i \cdot A_i
$

In [5]:
def compute_area_weighted_normals(V, F):
    Nu = []
    for vi, v in enumerate(V):
        # Get neighbouring faces
        Nf = [fi for fi, f in enumerate(F) if vi in f]
        # Compute normals
        Nv = jnp.zeros(3)
        A = 0
        for fi in Nf:
            f = F[fi]
            N = compute_normal(V[f[1]] - V[f[0]], V[f[2]] - V[f[0]])
            Ai = jnp.linalg.norm(N) * 0.5
            Nv += normalize(N) * Ai
            A += Ai
        Nv /= A
        Nu.append(normalize(Nv))
    return jnp.array(Nu)

N_v = compute_area_weighted_normals(V, F)

## Coding 10.3

### Compute the vertex normal using tip angle

$
N_\nu = \sum\limits_i \theta_i N_i
$

In [6]:
def compute_tip_angle_weighted_normals(V, F):
    Nu = []
    for vi, v in enumerate(V):
        # Get neighbouring faces
        Nf = [fi for fi, f in enumerate(F) if vi in f]
        # Compute normals
        Nv = jnp.zeros(3)
        for fi in Nf:
            f = F[fi]
            N = compute_normal(V[f[1]] - V[f[0]], V[f[2]] - V[f[0]])
            # Figure out "where" the vertex is located
            other_nodes = [fv for fv in f if fv != vi]
            v0 = V[other_nodes[0]] - V[vi]
            v1 = V[other_nodes[1]] - V[vi]
            theta_i = jnp.arccos(jnp.dot(v0, v1) / (jnp.linalg.norm(v0) * jnp.linalg.norm(v1)))
            Nv += normalize(N) * theta_i
        Nu.append(normalize(Nv))
    return jnp.array(Nu)

N_theta = compute_tip_angle_weighted_normals(V, F)

In [7]:
N_U_np = to_numpy_array(N_U)
N_v_np = to_numpy_array(N_v)
N_theta_np = to_numpy_array(N_theta)

d = mp.subplot(V_np,F_np, c = N_U_np, s=[1, 3, 0])
mp.subplot(V_np,F_np, c = N_v_np, s=[1, 3, 1], data=d)
mp.subplot(V_np,F_np, c = N_theta_np, s=[1, 3, 2], data=d)


HBox(children=(Output(), Output(), Output()))

## Coding 10.4

### Compute the the mean curvature

$
\Delta f = \nabla_{p_i} A = \frac{1}{2}\sum\limits_j (\cot \alpha_j +\cot \beta_j)(p_j - p_i) 
$

In [8]:
def compute_mean_curvature(V, F):
    nabla_P = []
    for p_i, v in enumerate(V):
        # Get neighbouring faces
        Nf = [fi for fi, f in enumerate(F) if p_i in f]
        # Get all other points
        P_j = jnp.unique(jnp.array([[pj for pj in F[f] if pj != p_i] for f in Nf]).flatten())
        # Compute normals
        nabla_p = jnp.zeros(3)
        for p_j in P_j:
            # Assume all vertices in P_j are part of exactly two faces
            f0, f1 = [f for f in Nf if p_j in F[f]]
            p_k = [fi for fi in F[f0] if fi != p_j and fi != p_i][0]
            p_m = [fi for fi in F[f1] if fi != p_j and fi != p_i][0]
            
            p_ki = V[p_i] - V[p_k]
            p_kj = V[p_j] - V[p_k]
            
            p_mi = V[p_i] - V[p_m]
            p_mj = V[p_j] - V[p_m]
            
            alpha = jnp.arccos(jnp.dot(p_ki, p_kj) / (jnp.linalg.norm(p_ki) * jnp.linalg.norm(p_kj)))
            beta  = jnp.arccos(jnp.dot(p_mi, p_mj) / (jnp.linalg.norm(p_mi) * jnp.linalg.norm(p_mj)))

            nabla_p += (1./jnp.tan(alpha) + 1./jnp.tan(beta)) * (V[p_i] - V[p_j])
            
        nabla_p *= 0.5
        nabla_P.append(nabla_p)
    return jnp.array(nabla_P)

def compute_mean_curvature_jax_old(V, F):
    A = jnp.zeros(V.shape[0])
    for vi, v in enumerate(V):
        # Get neighbouring faces
        #Nf = [fi for fi, f in enumerate(F) if vi in f]
        Nf = Nfs[vi]
        # Compute normals
        for fi in Nf:
            f = F[fi]
            #N = compute_normal(V[f[1]] - V[f[0]], V[f[2]] - V[f[0]])
            N = jnp.cross(V[f[1]] - V[f[0]], V[f[2]] - V[f[0]])
            A = jax.ops.index_add(A, vi, jnp.linalg.norm(N) * 0.5)
    return A

def compute_mean_curvature_jax(V, F):
    A = [None for i in range(V.shape[0])]#jnp.zeros(V.shape[0])
    for vi, v in enumerate(V):
        #Nf = [fi for fi, f in enumerate(F) if vi in f]
        Nf = Nfs[vi]
        ff = F[Nf]
        v1 = V[ff[:,1]] - V[ff[:,0]]
        v2 = V[ff[:,2]] - V[ff[:,0]]
        NN = vmap(jnp.cross)(v1, v2)
        A[vi] = jnp.sum(jnp.linalg.norm(NN, axis=1) * 0.5)

    return jnp.array(A)

def compute_mean_curvature_jax_new(V, F):
    #A = [None for i in range(V.shape[0])]
    #import numpy as nonp
    #A = nonp.empty(V.shape[0])
    #@partial(jnp.vectorize, signature='(n,m),(m)->(n)')
    def f(VFrow): #, V, F, Nfs
        #print(vi)
        #Nf = Nfs[vi]
        Nf = jnp.where(VFrow)
        ff = F[Nf]
        v1 = V[ff[:,1]] - V[ff[:,0]]
        v2 = V[ff[:,2]] - V[ff[:,0]]
        NN = compute_normal_vmap(v1, v2)
        return jnp.sum(jnp.linalg.norm(NN, axis=1) * 0.5)
        
    #h = jnp.vectorize(f, signature="(m)->()") # I can never figure out the signature
    # though, it complains about getting a tracer
    #A = jnp.apply_along_axis(f, 1, VF) # same as above, complains about getting a tracer
    A = jnp.array([f(VF[vi]) for vi in jnp.arange(V.shape[0])]) # I don't want to do this
    #h = vmap(f) #???

    #g = functools.partial(f, V=V, F=F, Nfs=Nfs)
    # This is a probably wrong, other way of doing the below...
    #g = lambda vi: f(vi, V, F, Nfs) # nope
    
    #h = vmap(g) # I have ...
    #A = h(Nfs)  # ... tried a bit ...
    #A = h(VF)   # ... of things.

    return A

now = time.time()
analytical = compute_mean_curvature(V, F)
then = time.time()
print("Analytical time spent:", then - now)
now = time.time()
dcompute_mean_curvature_jax_old = jacfwd(compute_mean_curvature_jax_old)
autodiff_old = dcompute_mean_curvature_jax_old(V, F)
autodiff_old = jnp.array([autodiff_old[i, i, :] for i in range(V.shape[0])])
then = time.time()
print("Autodiff time spent old:", then - now)
now = time.time()
dcompute_mean_curvature_jax = jacfwd(compute_mean_curvature_jax)
autodiff = dcompute_mean_curvature_jax(V, F)
autodiff = jnp.array([autodiff[i, i, :] for i in range(V.shape[0])])
then = time.time()
print("Autodiff time spent new:", then - now)
now = time.time()
dcompute_mean_curvature_jax_new = jacfwd(compute_mean_curvature_jax_new)
autodiff_new = dcompute_mean_curvature_jax_new(V, F)
autodiff_new = jnp.array([autodiff_new[i, i, :] for i in range(V.shape[0])])
then = time.time()
print("Autodiff time spent newer:", then - now)

Analytical time spent: 17.873614072799683
Autodiff time spent old: 14.788041830062866
Autodiff time spent new: 3.3326003551483154
Autodiff time spent newer: 3.015868663787842


In [62]:
#print(jnp.linalg.norm(analytical,axis = 1))
#print(autodiff_old)
#print(autodiff)

In [63]:
#print(analytical)
#print(autodiff)

In [14]:
analytical_np = to_numpy_array(analytical)
autodiff_old_np = to_numpy_array(autodiff_old)
autodiff_new_np = to_numpy_array(autodiff_new)

scale = 3

plot_curvature = mp.subplot(V_np,F_np, c = scale * analytical_np, s=[1, 2, 0])
mp.subplot(V_np,F_np, c = scale * autodiff_new_np, s=[1, 2, 1], data=plot_curvature)

HBox(children=(Output(), Output()))

In [15]:
plot_curvature2 = mp.subplot(V_np,F_np, c = np.linalg.norm(analytical_np,axis = 1), s=[1, 2, 0])
mp.subplot(V_np,F_np, c = np.linalg.norm(autodiff_new_np,axis = 1), s=[1, 2, 1], data=plot_curvature2)

HBox(children=(Output(), Output()))