In [1]:
import jax
import jax.numpy as jnp
from jax import grad, jacfwd, vmap

In [33]:
DATA = 'meshes/'
examplemesh = 'sphere2.obj'

def objloader(folder, fname):
    V = []
    F = []
    with open(folder + fname, 'r') as f:
        lines = f.readlines()
        for l in lines:
            token = l.split(' ')[0]
            if token == 'v':
                V.append(jnp.array([float(v) for v in l.split(' ')[1:]]))
            if token == 'f':
                F.append(jnp.array([int(f.split('/')[0]) - 1 for f in l.split(' ')[1:]]))
    #print(F)
    V = jnp.array(V)
    F = jnp.array(F, dtype=jnp.int32)
    
    return V, F

def get_edge(V, e0, e1):
    return V[e1] - V[e0]

def normalize(v):
    return v / jnp.linalg.norm(v)

def compute_normal_normalized(e0, e1):
    return normalize(jnp.cross(e0, e1))

def compute_normal(e0, e1):
    return jnp.cross(e0, e1)

V, F = objloader(DATA, examplemesh)
print(V.shape)
print(F.shape)


(162, 3)
(320, 3)


## Coding 10.1

### Compute the uniformly weighted vertex normals given by

$
N_U = \sum\limits_i N_i
$

In [3]:
def compute_uniformly_weighted_normals(V, F):
    Nu = []
    for vi, v in enumerate(V):
        # Get neighbouring faces
        Nf = [fi for fi, f in enumerate(F) if vi in f]
        # Compute normals
        Nv = jnp.zeros(3)
        for fi in Nf:
            f = F[fi]
            Nv += compute_normal_normalized(V[f[1]] - V[f[0]], V[f[2]] - V[f[0]])
        Nu.append(normalize(Nv))
    return jnp.array(Nu)
        
compute_uniformly_weighted_normals(V, F)

DeviceArray([[-0.57735026,  0.57735026, -0.57735026],
             [ 0.57735026,  0.57735026,  0.57735026],
             [ 0.57735026,  0.57735026, -0.57735026],
             [-0.57735026,  0.57735026,  0.57735026],
             [-0.57735026, -0.57735026,  0.57735026],
             [ 0.57735026, -0.57735026, -0.57735026],
             [ 0.57735026, -0.57735026,  0.57735026],
             [-0.57735026, -0.57735026, -0.57735026]], dtype=float32)

## Coding 10.2

### Compute the vertex normal using face area weights

$
N_\nu = \sum\limits_i N_i \cdot \frac{A_i}{\sum\limits_j A_j} = \frac{1}{\sum\limits_j A_j}\sum\limits_i N_i \cdot A_i
$

In [4]:
def compute_area_weighted_normals(V, F):
    Nu = []
    for vi, v in enumerate(V):
        # Get neighbouring faces
        Nf = [fi for fi, f in enumerate(F) if vi in f]
        # Compute normals
        Nv = jnp.zeros(3)
        A = 0
        for fi in Nf:
            f = F[fi]
            N = compute_normal(V[f[1]] - V[f[0]], V[f[2]] - V[f[0]])
            Ai = jnp.linalg.norm(N) * 0.5
            Nv += normalize(N) * Ai
            A += Ai
        Nv /= A
        Nu.append(normalize(Nv))
    return jnp.array(Nu)

compute_area_weighted_normals(V, F)

DeviceArray([[-0.04987546,  0.9975093 , -0.04987546],
             [ 0.04987546,  0.9975093 ,  0.04987546],
             [ 0.04987546,  0.9975093 , -0.04987546],
             [-0.04987546,  0.9975093 ,  0.04987546],
             [-0.04987546, -0.9975093 ,  0.04987546],
             [ 0.04987546, -0.9975093 , -0.04987546],
             [ 0.04987546, -0.9975093 ,  0.04987546],
             [-0.04987546, -0.9975093 , -0.04987546]], dtype=float32)

## Coding 10.3

### Compute the vertex normal using tip angle

$
N_\nu = \sum\limits_i \theta_i N_i
$

In [5]:
def compute_tip_angle_weighted_normals(V, F):
    Nu = []
    for vi, v in enumerate(V):
        # Get neighbouring faces
        Nf = [fi for fi, f in enumerate(F) if vi in f]
        # Compute normals
        Nv = jnp.zeros(3)
        for fi in Nf:
            f = F[fi]
            N = compute_normal(V[f[1]] - V[f[0]], V[f[2]] - V[f[0]])
            # Figure out "where" the vertex is located
            other_nodes = [fv for fv in f if fv != vi]
            v0 = V[other_nodes[0]] - V[vi]
            v1 = V[other_nodes[1]] - V[vi]
            theta_i = jnp.arccos(jnp.dot(v0, v1) / (jnp.linalg.norm(v0) * jnp.linalg.norm(v1)))
            Nv += normalize(N) * theta_i
        Nu.append(normalize(Nv))
    return jnp.array(Nu)

compute_tip_angle_weighted_normals(V, F)

DeviceArray([[-0.57735026,  0.5773503 , -0.57735026],
             [ 0.57735026,  0.5773503 ,  0.57735026],
             [ 0.57735026,  0.57735026, -0.57735026],
             [-0.57735026,  0.57735026,  0.57735026],
             [-0.57735026, -0.5773503 ,  0.57735026],
             [ 0.57735026, -0.5773503 , -0.57735026],
             [ 0.57735026, -0.57735026,  0.57735026],
             [-0.57735026, -0.57735026, -0.57735026]], dtype=float32)

## Coding 10.4

### Compute the the mean curvature

$
\Delta f = \nabla_{p_i} A = \frac{1}{2}\sum\limits_j (\cot \alpha_j +\cot \beta_j)(p_j - p_i) 
$

In [13]:
print(V)
print(F)

[[-0.5    0.025 -0.5  ]
 [ 0.5    0.025  0.5  ]
 [ 0.5    0.025 -0.5  ]
 [-0.5    0.025  0.5  ]
 [-0.5   -0.025  0.5  ]
 [ 0.5   -0.025 -0.5  ]
 [ 0.5   -0.025  0.5  ]
 [-0.5   -0.025 -0.5  ]]
[[0 1 2]
 [1 0 3]
 [4 5 6]
 [5 4 7]
 [4 1 3]
 [1 4 6]
 [1 5 2]
 [5 1 6]
 [5 0 2]
 [0 5 7]
 [4 0 7]
 [0 4 3]]


In [34]:
def compute_mean_curvature(V, F):
    nabla_P = []
    for p_i, v in enumerate(V):
        # Get neighbouring faces
        Nf = [fi for fi, f in enumerate(F) if p_i in f]
        # Get all other points
        P_j = jnp.unique(jnp.array([[pj for pj in F[f] if pj != p_i] for f in Nf]).flatten())
        # Compute normals
        nabla_p = jnp.zeros(3)
        for p_j in P_j:
            # Assume all vertices in P_j are part of exactly two faces
            f0, f1 = [f for f in Nf if p_j in F[f]]
            p_k = [fi for fi in F[f0] if fi != p_j and fi != p_i][0]
            p_m = [fi for fi in F[f1] if fi != p_j and fi != p_i][0]
            
            p_ki = V[p_i] - V[p_k]
            p_kj = V[p_j] - V[p_k]
            
            p_mi = V[p_i] - V[p_m]
            p_mj = V[p_j] - V[p_m]
            
            alpha = jnp.arccos(jnp.dot(p_ki, p_kj) / (jnp.linalg.norm(p_ki) * jnp.linalg.norm(p_kj)))
            beta  = jnp.arccos(jnp.dot(p_mi, p_mj) / (jnp.linalg.norm(p_mi) * jnp.linalg.norm(p_mj)))

            nabla_p += (1./jnp.tan(alpha) + 1./jnp.tan(beta)) * (V[p_i] - V[p_j])
            
        nabla_p *= 0.5
        nabla_P.append(nabla_p)
    return jnp.array(nabla_P)


def compute_mean_curvature_jax(V, F):
    A = jnp.zeros(V.shape[0])
    for vi, v in enumerate(V):
        # Get neighbouring faces
        Nf = [fi for fi, f in enumerate(F) if vi in f]
        # Compute normals
        for fi in Nf:
            f = F[fi]
            N = compute_normal(V[f[1]] - V[f[0]], V[f[2]] - V[f[0]])
            A = jax.ops.index_add(A, vi, jnp.linalg.norm(N) * 0.5)
    return A

compute_normal_vmap = vmap(compute_normal)

def compute_mean_curvature_jax_new(V, F):
    A = jnp.zeros(V.shape[0])
    for vi, v in enumerate(V):
        # Get neighbouring faces
        Nf = [fi for fi, f in enumerate(F) if vi in f]
        # Compute normals
        #print(F, Nf)
        #for fi in Nf:
        ff = F[Nf]
        #print('here')
        #print(F[Nf].shape, F[Nf[0]].shape)
        v1 = V[ff[:,1]] - V[ff[:,0]]
        v2 = V[ff[:,2]] - V[ff[:,0]]
        NN = compute_normal_vmap(v1, v2)
        #A = jnp.linalg.norm(NN, axis=0)
        A = jax.ops.index_add(A, vi, jnp.linalg.norm(NN, axis=0) * 0.5)
        #print(A.shape)
        #    N = compute_normal(V[f[1]] - V[f[0]], V[f[2]] - V[f[0]])
        #    A = jax.ops.index_add(A, vi, jnp.linalg.norm(N) * 0.5)
    return A

import time
now = time.time()
analytical = compute_mean_curvature(V, F)
then = time.time()
print("Analytical time spent:", then - now)
now = time.time()
dcompute_mean_curvature_jax = jacfwd(compute_mean_curvature_jax)
autodiff = dcompute_mean_curvature_jax(V, F)
autodiff = jnp.array([autodiff[i, i, :] for i in range(V.shape[0])])
then = time.time()
print("Autodiff time spent:", then - now)

Analytical time spent: 43.236945390701294
Autodiff time spent: 29.73582148551941


In [35]:
print(analytical)
print(autodiff)

[[-2.68220901e-07 -2.30016991e-01  0.00000000e+00]
 [ 1.66447356e-01 -1.02871060e-01  1.20929629e-01]
 [-6.35757148e-02 -1.02871060e-01  1.95669174e-01]
 [-2.05733627e-01 -1.02867544e-01  0.00000000e+00]
 [-6.35757446e-02 -1.02871090e-01 -1.95669174e-01]
 [ 1.66447356e-01 -1.02871060e-01 -1.20929629e-01]
 [ 6.35757595e-02  1.02871090e-01  1.95669189e-01]
 [-1.66447371e-01  1.02871031e-01  1.20929614e-01]
 [-1.66447356e-01  1.02871031e-01 -1.20929614e-01]
 [ 6.35757595e-02  1.02871060e-01 -1.95669189e-01]
 [ 2.05733627e-01  1.02867544e-01  0.00000000e+00]
 [ 2.68220901e-07  2.30017006e-01  0.00000000e+00]
 [-6.55584931e-02 -1.90183967e-01  2.01771289e-01]
 [-5.26096523e-02 -2.75475174e-01  1.61918312e-01]
 [-2.32449472e-02 -2.74806023e-01  7.15416968e-02]
 [ 6.08569682e-02 -2.74805218e-01  4.42149639e-02]
 [ 1.37734503e-01 -2.75471568e-01  1.00069165e-01]
 [ 1.71636492e-01 -1.90183133e-01  1.24699891e-01]
 [ 1.48392737e-01 -1.43692583e-01  1.96245104e-01]
 [ 8.51278156e-02 -1.70255125e-

In [28]:
print(analytical)
print(autodiff)

[[ 3.16649675e-08 -2.58487109e-02  7.45058060e-09]
 [ 1.87046193e-02 -1.15602612e-02  1.35896355e-02]
 [-7.14459643e-03 -1.15605965e-02  2.19891742e-02]
 ...
 [ 1.67238973e-02 -3.37766707e-02  1.91537440e-02]
 [ 1.93499848e-02 -2.74843723e-02  2.04089321e-02]
 [ 1.96325853e-02 -2.02727988e-02  1.95527785e-02]]
[[-7.4505806e-09 -2.5848709e-02  0.0000000e+00]
 [ 1.8704623e-02 -1.1560261e-02  1.3589628e-02]
 [-7.1446002e-03 -1.1560600e-02  2.1989167e-02]
 ...
 [ 1.6723894e-02 -3.3776663e-02  1.9153751e-02]
 [ 1.9349955e-02 -2.7484374e-02  2.0408962e-02]
 [ 1.9632582e-02 -2.0272801e-02  1.9552790e-02]]
