In [1]:
import time
import functools
import numpy as np
import jax
import jax.numpy as jnp
from jax import grad, jacfwd, vmap, jit

In [2]:
DATA = 'meshes/'
examplemesh = 'sphere2.obj'
#examplemesh = 'testmesh.obj'

def objloader(folder, fname):
    V = []
    F = []
    with open(folder + fname, 'r') as f:
        lines = f.readlines()
        for l in lines:
            token = l.split(' ')[0]
            if token == 'v':
                V.append(jnp.array([float(v) for v in l.split(' ')[1:]]))
            if token == 'f':
                F.append(jnp.array([int(f.split('/')[0]) - 1 for f in l.split(' ')[1:]]))
    #print(F)
    V = jnp.array(V)
    F = jnp.array(F, dtype=jnp.int32)
    
    return V, F

def get_edge(V, e0, e1):
    return V[e1] - V[e0]

def normalize(v):
    return v / jnp.linalg.norm(v)

def compute_normal_normalized(e0, e1):
    return normalize(jnp.cross(e0, e1))

def compute_normal(e0, e1):
    return jnp.cross(e0, e1)

compute_normal_vmap = vmap(jnp.cross)

V, F = objloader(DATA, examplemesh)
Nfs = [[fi for fi, f in enumerate(F) if vi in f] for vi in jnp.arange(V.shape[0])]
VF = np.zeros((V.shape[0], F.shape[0]),dtype=np.int32)

for vi in np.arange(V.shape[0]):
    Nf = Nfs[vi]

    for fi in Nf:
        VF[vi, fi] = 1
        
VF = jnp.array(VF)

print(V.shape)
print(F.shape)




(162, 3)
(320, 3)


## Coding 10.1

### Compute the uniformly weighted vertex normals given by

$
N_U = \sum\limits_i N_i
$

In [3]:
def compute_uniformly_weighted_normals(V, F):
    Nu = []
    for vi, v in enumerate(V):
        # Get neighbouring faces
        Nf = [fi for fi, f in enumerate(F) if vi in f]
        # Compute normals
        Nv = jnp.zeros(3)
        for fi in Nf:
            f = F[fi]
            Nv += compute_normal_normalized(V[f[1]] - V[f[0]], V[f[2]] - V[f[0]])
        Nu.append(normalize(Nv))
    return jnp.array(Nu)
        
compute_uniformly_weighted_normals(V, F)

DeviceArray([[-2.2929834e-07, -1.0000000e+00,  0.0000000e+00],
             [ 7.2360772e-01, -4.4721887e-01,  5.2572542e-01],
             [-2.7638873e-01, -4.4722000e-01,  8.5064882e-01],
             [-8.9442575e-01, -4.4721648e-01,  0.0000000e+00],
             [-2.7638873e-01, -4.4722000e-01, -8.5064882e-01],
             [ 7.2360772e-01, -4.4721884e-01, -5.2572542e-01],
             [ 2.7638876e-01,  4.4722000e-01,  8.5064894e-01],
             [-7.2360772e-01,  4.4721887e-01,  5.2572542e-01],
             [-7.2360772e-01,  4.4721887e-01, -5.2572542e-01],
             [ 2.7638873e-01,  4.4722000e-01, -8.5064882e-01],
             [ 8.9442575e-01,  4.4721654e-01,  3.0170832e-09],
             [ 2.3231543e-07,  1.0000000e+00,  0.0000000e+00],
             [-2.3226970e-01, -6.5956312e-01,  7.1486163e-01],
             [-1.6245618e-01, -8.5065401e-01,  4.9999565e-01],
             [-7.8419305e-02, -9.6726358e-01,  2.4135327e-01],
             [ 2.0530809e-01, -9.6726346e-01,  1.491641

## Coding 10.2

### Compute the vertex normal using face area weights

$
N_\nu = \sum\limits_i N_i \cdot \frac{A_i}{\sum\limits_j A_j} = \frac{1}{\sum\limits_j A_j}\sum\limits_i N_i \cdot A_i
$

In [4]:
def compute_area_weighted_normals(V, F):
    Nu = []
    for vi, v in enumerate(V):
        # Get neighbouring faces
        Nf = [fi for fi, f in enumerate(F) if vi in f]
        # Compute normals
        Nv = jnp.zeros(3)
        A = 0
        for fi in Nf:
            f = F[fi]
            N = compute_normal(V[f[1]] - V[f[0]], V[f[2]] - V[f[0]])
            Ai = jnp.linalg.norm(N) * 0.5
            Nv += normalize(N) * Ai
            A += Ai
        Nv /= A
        Nu.append(normalize(Nv))
    return jnp.array(Nu)

compute_area_weighted_normals(V, F)

DeviceArray([[-8.5700719e-07, -1.0000000e+00,  0.0000000e+00],
             [ 7.2360784e-01, -4.4721809e-01,  5.2572584e-01],
             [-2.7638942e-01, -4.4721958e-01,  8.5064888e-01],
             [-8.9442569e-01, -4.4721660e-01,  0.0000000e+00],
             [-2.7638945e-01, -4.4721958e-01, -8.5064888e-01],
             [ 7.2360784e-01, -4.4721815e-01, -5.2572578e-01],
             [ 2.7638945e-01,  4.4721958e-01,  8.5064888e-01],
             [-7.2360790e-01,  4.4721809e-01,  5.2572578e-01],
             [-7.2360790e-01,  4.4721812e-01, -5.2572578e-01],
             [ 2.7638942e-01,  4.4721958e-01, -8.5064888e-01],
             [ 8.9442569e-01,  4.4721660e-01,  0.0000000e+00],
             [ 8.5700719e-07,  1.0000000e+00,  0.0000000e+00],
             [-2.2995479e-01, -6.6800559e-01,  7.0773530e-01],
             [-1.6245677e-01, -8.5065383e-01,  4.9999595e-01],
             [-8.1788935e-02, -9.6433735e-01,  2.5172204e-01],
             [ 2.1412830e-01, -9.6433723e-01,  1.555725

## Coding 10.3

### Compute the vertex normal using tip angle

$
N_\nu = \sum\limits_i \theta_i N_i
$

In [5]:
def compute_tip_angle_weighted_normals(V, F):
    Nu = []
    for vi, v in enumerate(V):
        # Get neighbouring faces
        Nf = [fi for fi, f in enumerate(F) if vi in f]
        # Compute normals
        Nv = jnp.zeros(3)
        for fi in Nf:
            f = F[fi]
            N = compute_normal(V[f[1]] - V[f[0]], V[f[2]] - V[f[0]])
            # Figure out "where" the vertex is located
            other_nodes = [fv for fv in f if fv != vi]
            v0 = V[other_nodes[0]] - V[vi]
            v1 = V[other_nodes[1]] - V[vi]
            theta_i = jnp.arccos(jnp.dot(v0, v1) / (jnp.linalg.norm(v0) * jnp.linalg.norm(v1)))
            Nv += normalize(N) * theta_i
        Nu.append(normalize(Nv))
    return jnp.array(Nu)

compute_tip_angle_weighted_normals(V, F)

DeviceArray([[-7.1248417e-07, -1.0000000e+00,  2.4234157e-09],
             [ 7.2360790e-01, -4.4721878e-01,  5.2572525e-01],
             [-2.7638835e-01, -4.4722021e-01,  8.5064900e-01],
             [-8.9442581e-01, -4.4721639e-01,  0.0000000e+00],
             [-2.7638835e-01, -4.4722021e-01, -8.5064900e-01],
             [ 7.2360778e-01, -4.4721875e-01, -5.2572525e-01],
             [ 2.7638832e-01,  4.4722012e-01,  8.5064900e-01],
             [-7.2360790e-01,  4.4721878e-01,  5.2572525e-01],
             [-7.2360778e-01,  4.4721875e-01, -5.2572525e-01],
             [ 2.7638829e-01,  4.4722012e-01, -8.5064894e-01],
             [ 8.9442581e-01,  4.4721642e-01,  2.4234150e-09],
             [ 7.1490763e-07,  1.0000000e+00, -1.2117078e-09],
             [-2.3079200e-01, -6.6497254e-01,  7.1031433e-01],
             [-1.6245601e-01, -8.5065424e-01,  4.9999544e-01],
             [-8.0574967e-02, -9.6540642e-01,  2.4798816e-01],
             [ 2.1095186e-01, -9.6540624e-01,  1.532644

## Coding 10.4

### Compute the the mean curvature

$
\Delta f = \nabla_{p_i} A = \frac{1}{2}\sum\limits_j (\cot \alpha_j +\cot \beta_j)(p_j - p_i) 
$

In [6]:
print(V)
print(F)

[[ 0.000000e+00 -2.000000e+00  0.000000e+00]
 [ 1.447215e+00 -8.944390e-01  1.051451e+00]
 [-5.527760e-01 -8.944400e-01  1.701298e+00]
 [-1.788852e+00 -8.944310e-01  0.000000e+00]
 [-5.527760e-01 -8.944400e-01 -1.701298e+00]
 [ 1.447215e+00 -8.944390e-01 -1.051451e+00]
 [ 5.527760e-01  8.944400e-01  1.701298e+00]
 [-1.447215e+00  8.944390e-01  1.051451e+00]
 [-1.447215e+00  8.944390e-01 -1.051451e+00]
 [ 5.527760e-01  8.944400e-01 -1.701298e+00]
 [ 1.788852e+00  8.944310e-01  0.000000e+00]
 [ 0.000000e+00  2.000000e+00  0.000000e+00]
 [-4.656430e-01 -1.315038e+00  1.433126e+00]
 [-3.249110e-01 -1.701309e+00  9.999900e-01]
 [-1.552130e-01 -1.935899e+00  4.777050e-01]
 [ 4.063620e-01 -1.935899e+00  2.952360e-01]
 [ 8.506450e-01 -1.701308e+00  6.180230e-01]
 [ 1.219093e+00 -1.315038e+00  8.857130e-01]
 [ 1.063882e+00 -1.004603e+00  1.363425e+00]
 [ 5.257380e-01 -1.051475e+00  1.618023e+00]
 [-5.927900e-02 -1.004604e+00  1.728368e+00]
 [ 1.625458e+00 -1.004601e+00 -5.904750e-01]
 [ 1.70129

In [7]:
def compute_mean_curvature(V, F):
    nabla_P = []
    for p_i, v in enumerate(V):
        # Get neighbouring faces
        Nf = [fi for fi, f in enumerate(F) if p_i in f]
        # Get all other points
        P_j = jnp.unique(jnp.array([[pj for pj in F[f] if pj != p_i] for f in Nf]).flatten())
        # Compute normals
        nabla_p = jnp.zeros(3)
        for p_j in P_j:
            # Assume all vertices in P_j are part of exactly two faces
            f0, f1 = [f for f in Nf if p_j in F[f]]
            p_k = [fi for fi in F[f0] if fi != p_j and fi != p_i][0]
            p_m = [fi for fi in F[f1] if fi != p_j and fi != p_i][0]
            
            p_ki = V[p_i] - V[p_k]
            p_kj = V[p_j] - V[p_k]
            
            p_mi = V[p_i] - V[p_m]
            p_mj = V[p_j] - V[p_m]
            
            alpha = jnp.arccos(jnp.dot(p_ki, p_kj) / (jnp.linalg.norm(p_ki) * jnp.linalg.norm(p_kj)))
            beta  = jnp.arccos(jnp.dot(p_mi, p_mj) / (jnp.linalg.norm(p_mi) * jnp.linalg.norm(p_mj)))

            nabla_p += (1./jnp.tan(alpha) + 1./jnp.tan(beta)) * (V[p_i] - V[p_j])
            
        nabla_p *= 0.5
        nabla_P.append(nabla_p)
    return jnp.array(nabla_P)

def compute_mean_curvature_jax_old(V, F):
    A = jnp.zeros(V.shape[0])
    for vi, v in enumerate(V):
        # Get neighbouring faces
        #Nf = [fi for fi, f in enumerate(F) if vi in f]
        Nf = Nfs[vi]
        # Compute normals
        for fi in Nf:
            f = F[fi]
            #N = compute_normal(V[f[1]] - V[f[0]], V[f[2]] - V[f[0]])
            N = jnp.cross(V[f[1]] - V[f[0]], V[f[2]] - V[f[0]])
            A = jax.ops.index_add(A, vi, jnp.linalg.norm(N) * 0.5)
    return A

def compute_mean_curvature_jax(V, F):
    A = [None for i in range(V.shape[0])]#jnp.zeros(V.shape[0])
    for vi, v in enumerate(V):
        #Nf = [fi for fi, f in enumerate(F) if vi in f]
        Nf = Nfs[vi]
        ff = F[Nf]
        v1 = V[ff[:,1]] - V[ff[:,0]]
        v2 = V[ff[:,2]] - V[ff[:,0]]
        NN = vmap(jnp.cross)(v1, v2)
        A[vi] = jnp.sum(jnp.linalg.norm(NN, axis=1) * 0.5)

    return jnp.array(A)

def compute_mean_curvature_jax_new(V, F):
    #A = [None for i in range(V.shape[0])]
    #import numpy as nonp
    #A = nonp.empty(V.shape[0])
    #@partial(jnp.vectorize, signature='(n,m),(m)->(n)')
    def f(VFrow): #, V, F, Nfs
        #print(vi)
        #Nf = Nfs[vi]
        Nf = jnp.where(VFrow)
        ff = F[Nf]
        v1 = V[ff[:,1]] - V[ff[:,0]]
        v2 = V[ff[:,2]] - V[ff[:,0]]
        NN = compute_normal_vmap(v1, v2)
        return jnp.sum(jnp.linalg.norm(NN, axis=1) * 0.5)
        
    #h = jnp.vectorize(f, signature="(m)->()") # I can never figure out the signature
    # though, it complains about getting a tracer
    #A = jnp.apply_along_axis(f, 1, VF) # same as above, complains about getting a tracer
    A = jnp.array([f(VF[vi]) for vi in jnp.arange(V.shape[0])]) # I don't want to do this
    #h = vmap(f) #???

    #g = functools.partial(f, V=V, F=F, Nfs=Nfs)
    # This is a probably wrong, other way of doing the below...
    #g = lambda vi: f(vi, V, F, Nfs) # nope
    
    #h = vmap(g) # I have ...
    #A = h(Nfs)  # ... tried a bit ...
    #A = h(VF)   # ... of things.

    return A

now = time.time()
analytical = compute_mean_curvature(V, F)
then = time.time()
print("Analytical time spent:", then - now)
now = time.time()
dcompute_mean_curvature_jax_old = jacfwd(compute_mean_curvature_jax_old)
autodiff_old = dcompute_mean_curvature_jax_old(V, F)
autodiff_old = jnp.array([autodiff_old[i, i, :] for i in range(V.shape[0])])
then = time.time()
print("Autodiff time spent old:", then - now)
now = time.time()
dcompute_mean_curvature_jax = jacfwd(compute_mean_curvature_jax)
autodiff = dcompute_mean_curvature_jax(V, F)
autodiff = jnp.array([autodiff[i, i, :] for i in range(V.shape[0])])
then = time.time()
print("Autodiff time spent new:", then - now)
now = time.time()
dcompute_mean_curvature_jax_new = jacfwd(compute_mean_curvature_jax_new)
autodiff_new = dcompute_mean_curvature_jax_new(V, F)
autodiff_new = jnp.array([autodiff_new[i, i, :] for i in range(V.shape[0])])
then = time.time()
print("Autodiff time spent newer:", then - now)

Analytical time spent: 25.508700847625732
Autodiff time spent old: 21.232945919036865
Autodiff time spent new: 4.404656887054443
Autodiff time spent newer: 4.176539421081543


In [8]:
print(analytical)
print(autodiff_old)
print(autodiff)

[[-2.68220901e-07 -2.30016991e-01  0.00000000e+00]
 [ 1.66447356e-01 -1.02871060e-01  1.20929629e-01]
 [-6.35757148e-02 -1.02871060e-01  1.95669174e-01]
 [-2.05733627e-01 -1.02867544e-01  0.00000000e+00]
 [-6.35757446e-02 -1.02871090e-01 -1.95669174e-01]
 [ 1.66447356e-01 -1.02871060e-01 -1.20929629e-01]
 [ 6.35757595e-02  1.02871090e-01  1.95669189e-01]
 [-1.66447371e-01  1.02871031e-01  1.20929614e-01]
 [-1.66447356e-01  1.02871031e-01 -1.20929614e-01]
 [ 6.35757595e-02  1.02871060e-01 -1.95669189e-01]
 [ 2.05733627e-01  1.02867544e-01  0.00000000e+00]
 [ 2.68220901e-07  2.30017006e-01  0.00000000e+00]
 [-6.55584931e-02 -1.90183967e-01  2.01771289e-01]
 [-5.26096523e-02 -2.75475174e-01  1.61918312e-01]
 [-2.32449472e-02 -2.74806023e-01  7.15416968e-02]
 [ 6.08569682e-02 -2.74805218e-01  4.42149639e-02]
 [ 1.37734503e-01 -2.75471568e-01  1.00069165e-01]
 [ 1.71636492e-01 -1.90183133e-01  1.24699891e-01]
 [ 1.48392737e-01 -1.43692583e-01  1.96245104e-01]
 [ 8.51278156e-02 -1.70255125e-

In [9]:
print(analytical)
print(autodiff)

[[-2.68220901e-07 -2.30016991e-01  0.00000000e+00]
 [ 1.66447356e-01 -1.02871060e-01  1.20929629e-01]
 [-6.35757148e-02 -1.02871060e-01  1.95669174e-01]
 [-2.05733627e-01 -1.02867544e-01  0.00000000e+00]
 [-6.35757446e-02 -1.02871090e-01 -1.95669174e-01]
 [ 1.66447356e-01 -1.02871060e-01 -1.20929629e-01]
 [ 6.35757595e-02  1.02871090e-01  1.95669189e-01]
 [-1.66447371e-01  1.02871031e-01  1.20929614e-01]
 [-1.66447356e-01  1.02871031e-01 -1.20929614e-01]
 [ 6.35757595e-02  1.02871060e-01 -1.95669189e-01]
 [ 2.05733627e-01  1.02867544e-01  0.00000000e+00]
 [ 2.68220901e-07  2.30017006e-01  0.00000000e+00]
 [-6.55584931e-02 -1.90183967e-01  2.01771289e-01]
 [-5.26096523e-02 -2.75475174e-01  1.61918312e-01]
 [-2.32449472e-02 -2.74806023e-01  7.15416968e-02]
 [ 6.08569682e-02 -2.74805218e-01  4.42149639e-02]
 [ 1.37734503e-01 -2.75471568e-01  1.00069165e-01]
 [ 1.71636492e-01 -1.90183133e-01  1.24699891e-01]
 [ 1.48392737e-01 -1.43692583e-01  1.96245104e-01]
 [ 8.51278156e-02 -1.70255125e-