In [44]:
import jax
import jax.numpy as jnp
from jax import grad, jacfwd

In [2]:
DATA = 'meshes/'
examplemesh = 'box.obj'

def objloader(folder, fname):
    V = []
    F = []
    with open(folder + fname, 'r') as f:
        lines = f.readlines()
        for l in lines:
            token = l.split(' ')[0]
            if token == 'v':
                V.append(jnp.array([float(v) for v in l.split(' ')[1:]]))
            if token == 'f':
                F.append(jnp.array([int(f.split('/')[0]) - 1 for f in l.split(' ')[1:]]))
    print(F)
    V = jnp.array(V)
    F = jnp.array(F, dtype=jnp.int32)
    
    return V, F

def get_edge(V, e0, e1):
    return V[e1] - V[e0]

def normalize(v):
    return v / jnp.linalg.norm(v)

def compute_normal_normalized(e0, e1):
    return normalize(jnp.cross(e0, e1))

def compute_normal(e0, e1):
    return jnp.cross(e0, e1)

V, F = objloader(DATA, examplemesh)
print(F)

[DeviceArray([0, 1, 2], dtype=int32), DeviceArray([1, 0, 3], dtype=int32), DeviceArray([4, 5, 6], dtype=int32), DeviceArray([5, 4, 7], dtype=int32), DeviceArray([4, 1, 3], dtype=int32), DeviceArray([1, 4, 6], dtype=int32), DeviceArray([1, 5, 2], dtype=int32), DeviceArray([5, 1, 6], dtype=int32), DeviceArray([5, 0, 2], dtype=int32), DeviceArray([0, 5, 7], dtype=int32), DeviceArray([4, 0, 7], dtype=int32), DeviceArray([0, 4, 3], dtype=int32)]
[[0 1 2]
 [1 0 3]
 [4 5 6]
 [5 4 7]
 [4 1 3]
 [1 4 6]
 [1 5 2]
 [5 1 6]
 [5 0 2]
 [0 5 7]
 [4 0 7]
 [0 4 3]]




## Coding 10.1

### Compute the uniformly weighted vertex normals given by

$
N_U = \sum\limits_i N_i
$

In [3]:
def compute_uniformly_weighted_normals(V, F):
    Nu = []
    for vi, v in enumerate(V):
        # Get neighbouring faces
        Nf = [fi for fi, f in enumerate(F) if vi in f]
        # Compute normals
        Nv = jnp.zeros(3)
        for fi in Nf:
            f = F[fi]
            Nv += compute_normal_normalized(V[f[1]] - V[f[0]], V[f[2]] - V[f[0]])
        Nu.append(normalize(Nv))
    return jnp.array(Nu)
        
compute_uniformly_weighted_normals(V, F)

DeviceArray([[-0.57735026,  0.57735026, -0.57735026],
             [ 0.57735026,  0.57735026,  0.57735026],
             [ 0.57735026,  0.57735026, -0.57735026],
             [-0.57735026,  0.57735026,  0.57735026],
             [-0.57735026, -0.57735026,  0.57735026],
             [ 0.57735026, -0.57735026, -0.57735026],
             [ 0.57735026, -0.57735026,  0.57735026],
             [-0.57735026, -0.57735026, -0.57735026]], dtype=float32)

## Coding 10.2

### Compute the vertex normal using face area weights

$
N_\nu = \sum\limits_i N_i \cdot \frac{A_i}{\sum\limits_j A_j} = \frac{1}{\sum\limits_j A_j}\sum\limits_i N_i \cdot A_i
$

In [4]:
def compute_area_weighted_normals(V, F):
    Nu = []
    for vi, v in enumerate(V):
        # Get neighbouring faces
        Nf = [fi for fi, f in enumerate(F) if vi in f]
        # Compute normals
        Nv = jnp.zeros(3)
        A = 0
        for fi in Nf:
            f = F[fi]
            N = compute_normal(V[f[1]] - V[f[0]], V[f[2]] - V[f[0]])
            Ai = jnp.linalg.norm(N) * 0.5
            Nv += normalize(N) * Ai
            A += Ai
        Nv /= A
        Nu.append(normalize(Nv))
    return jnp.array(Nu)

compute_area_weighted_normals(V, F)

DeviceArray([[-0.04987546,  0.9975093 , -0.04987546],
             [ 0.04987546,  0.9975093 ,  0.04987546],
             [ 0.04987546,  0.9975093 , -0.04987546],
             [-0.04987546,  0.9975093 ,  0.04987546],
             [-0.04987546, -0.9975093 ,  0.04987546],
             [ 0.04987546, -0.9975093 , -0.04987546],
             [ 0.04987546, -0.9975093 ,  0.04987546],
             [-0.04987546, -0.9975093 , -0.04987546]], dtype=float32)

## Coding 10.3

### Compute the vertex normal using tip angle

$
N_\nu = \sum\limits_i \theta_i N_i
$

In [5]:
def compute_tip_angle_weighted_normals(V, F):
    Nu = []
    for vi, v in enumerate(V):
        # Get neighbouring faces
        Nf = [fi for fi, f in enumerate(F) if vi in f]
        # Compute normals
        Nv = jnp.zeros(3)
        for fi in Nf:
            f = F[fi]
            N = compute_normal(V[f[1]] - V[f[0]], V[f[2]] - V[f[0]])
            # Figure out "where" the vertex is located
            other_nodes = [fv for fv in f if fv != vi]
            v0 = V[other_nodes[0]] - V[vi]
            v1 = V[other_nodes[1]] - V[vi]
            theta_i = jnp.arccos(jnp.dot(v0, v1) / (jnp.linalg.norm(v0) * jnp.linalg.norm(v1)))
            Nv += normalize(N) * theta_i
        Nu.append(normalize(Nv))
    return jnp.array(Nu)

compute_tip_angle_weighted_normals(V, F)

DeviceArray([[-0.57735026,  0.5773503 , -0.57735026],
             [ 0.57735026,  0.5773503 ,  0.57735026],
             [ 0.57735026,  0.57735026, -0.57735026],
             [-0.57735026,  0.57735026,  0.57735026],
             [-0.57735026, -0.5773503 ,  0.57735026],
             [ 0.57735026, -0.5773503 , -0.57735026],
             [ 0.57735026, -0.57735026,  0.57735026],
             [-0.57735026, -0.57735026, -0.57735026]], dtype=float32)

## Coding 10.4

### Compute the the mean curvature

$
\Delta f = \nabla_{p_i} A = \frac{1}{2}\sum\limits_j (\cot \alpha_j +\cot \beta_j)(p_j - p_i) 
$

In [42]:
def compute_mean_curvature(V, F):
    nabla_P = []
    for p_i, v in enumerate(V):
        # Get neighbouring faces
        Nf = [fi for fi, f in enumerate(F) if p_i in f]
        # Get all other points
        P_j = jnp.unique(jnp.array([[pj for pj in F[f] if pj != p_i] for f in Nf]).flatten())
        # Compute normals
        nabla_p = jnp.zeros(3)
        for p_j in P_j:
            # Assume all vertices in P_j are part of exactly two faces
            f0, f1 = [f for f in Nf if p_j in F[f]]
            p_k = [fi for fi in F[f0] if fi != p_j and fi != p_i][0]
            p_m = [fi for fi in F[f1] if fi != p_j and fi != p_i][0]
            
            p_ki = V[p_i] - V[p_k]
            p_kj = V[p_j] - V[p_k]
            
            p_mi = V[p_i] - V[p_m]
            p_mj = V[p_j] - V[p_m]
            
            alpha = jnp.arccos(jnp.dot(p_ki, p_kj) / (jnp.linalg.norm(p_ki) * jnp.linalg.norm(p_kj)))
            beta  = jnp.arccos(jnp.dot(p_mi, p_mj) / (jnp.linalg.norm(p_mi) * jnp.linalg.norm(p_mj)))
            
            nabla_p += (1./jnp.tan(alpha) + 1./jnp.tan(beta)) * (V[p_j] - V[p_i])
            
        nabla_p *= 0.5
        nabla_P.append(nabla_p)
    return jnp.array(nabla_P)


def compute_mean_curvature_jax(V, F):
    A = jnp.zeros(V.shape[0])
    for vi, v in enumerate(V):
        # Get neighbouring faces
        Nf = [fi for fi, f in enumerate(F) if vi in f]
        # Compute normals
        for fi in Nf:
            f = F[fi]
            N = compute_normal(V[f[1]] - V[f[0]], V[f[2]] - V[f[0]])
            A = jax.ops.index_add(A, vi, jnp.linalg.norm(N) * 0.5)
    return A

analytical = compute_mean_curvature(V, F)

dcompute_mean_curvature_jax = jacfwd(compute_mean_curvature_jax)
autodiff = dcompute_mean_curvature_jax(V, F)
autodiff = jnp.array([autodiff[i, i, :] for i in range(V.shape[0])])

In [43]:
print(analytical)
print(autodiff)

[[ 0.525      -1.0000067   0.525     ]
 [-0.525      -1.0000067  -0.525     ]
 [-0.5250001  -1.0000067   0.5250001 ]
 [ 0.5250001  -1.0000067  -0.5250001 ]
 [ 0.52500004  1.0000067  -0.52500004]
 [-0.52500004  1.0000067   0.52500004]
 [-0.5250001   1.0000067  -0.5250001 ]
 [ 0.5250001   1.0000067   0.5250001 ]]
[[-0.525  1.    -0.525]
 [ 0.525  1.     0.525]
 [ 0.525  1.    -0.525]
 [-0.525  1.     0.525]
 [-0.525 -1.     0.525]
 [ 0.525 -1.    -0.525]
 [ 0.525 -1.     0.525]
 [-0.525 -1.    -0.525]]
