
# Optimisation with HKS
This notebook demonstrates an application of the Heat Kernel Signature (HKS). A smoothed mesh is transformed back into its pre-smoothed form (the ground-truth mesh) by minimising the error between the HKS of the smoothed mesh and the HKS of the ground-truth mesh. The minimisation is done with gradient descent on the vertex positions of the smoothed mesh.

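This application is meant to show that the HKS is representative of a shape's geometry, i.e. that matching the signature alone moves the smoothed mesh back towards the original.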

The notebook uses the HKS implementation in [this repo](https://github.com/andreusjh99/Heat-Kernel-Signature).

**Brief steps**:
1.   A bunny mesh is loaded (as the ground truth).
2.   A copy of the mesh is smoothed to create the smoothed mesh.
3.   The HKS of the original mesh is computed and normalised (the ground-truth HKS).
4.   Gradient descent on the smoothed mesh's vertex positions minimises the difference between its HKS and the ground-truth HKS, transforming it back towards the original mesh.
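
For reference, the HKS of a point $x$ at time scale $t$ is usually written as the diagonal of the heat kernel $k_t$, expanded in the eigenvalues $\lambda_i$ and eigenfunctions $\phi_i$ of the surface's Laplace–Beltrami operator. This is the standard definition only; the repository's `hks_t` may discretise or truncate the sum in its own way.

$$
h_t(x) = k_t(x, x) = \sum_{i \ge 0} e^{-\lambda_i t}\, \phi_i(x)^2
$$

Small values of $t$ make the signature sensitive to local geometry, larger values to more global structure.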





#### **Setup**

Install and import the required libraries, and clone the repository that contains the bunny mesh and the HKS implementation.

In [1]:
!pip install tensorflow_graphics
!pip install trimesh

!pip install kaleido
!pip install plotly==4.11.0


In [2]:
!git clone https://github.com/andreusjh99/Heat-Kernel-Signature.git

Cloning into 'Heat-Kernel-Signature'...
remote: Enumerating objects: 37, done.
remote: Counting objects: 100% (37/37), done.
remote: Compressing objects: 100% (35/35), done.
remote: Total 37 (delta 14), reused 0 (delta 0), pack-reused 0
Unpacking objects: 100% (37/37), done.


In [3]:
%cd Heat-Kernel-Signature

/content/Heat-Kernel-Signature


In [4]:
!git pull

Already up to date.


In [5]:
%ls

bunny4k.obj               hks_cow.png    Implementation_of_HKS.ipynb
cow.obj                   homer4k.obj    README.md
heat_kernel_signature.py  homer_hks.png


In [6]:
import numpy as np
import tensorflow as tf
import tensorflow_graphics as tfg
import trimesh

import plotly.graph_objects as go
from IPython.display import Image

import heat_kernel_signature as hks

#### **Helper functions**

In [18]:
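def plot_mesh(bunny, width=400, height=400, cam_x = 1.25, cam_y = -1.25, wireframe = False, mode="plot"):
    """Plot a trimesh mesh with Plotly.

    mode: "plot" shows the interactive figure, "image" returns PNG bytes
    (rendered with kaleido) and "gif" returns the list of traces, which is
    used to build the animation frames.
    """

    # Permute the axes so the mesh's y axis becomes Plotly's vertical z axis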
    y=bunny.vertices[:, 0]
    z=bunny.vertices[:, 1]
    x=bunny.vertices[:, 2]

    i = bunny.faces[:, 0]
    j = bunny.faces[:, 1]
    k = bunny.faces[:, 2]

    pl_mygrey=[0, 'rgb(140,140,140)'], [1., 'rgb(255,255,255)']

    mesh = go.Mesh3d(
            # vertex coordinates (x, y, z) and triangle vertex indices (i, j, k)
            x = x,
            y = y,
            z = z,
            intensity= z,
            colorscale = pl_mygrey,
            flatshading = True,
            i = i,
            j = j,
            k = k,
            showscale=False
        )

    mesh.update(cmin=-7,  # trick for a nicer plot: a cmin well below the data keeps the shading in the light part of the grey scale
        lighting=dict(ambient=0.18,
                        diffuse=1,
                        fresnel=0.1,
                        specular=1,
                        roughness=0.05,
                        facenormalsepsilon=1e-15,
                        vertexnormalsepsilon=1e-15),
        lightposition=dict(x=100, y=200, z=0)
        )

    if wireframe:
        triangles = np.vstack((i,j,k)).T
        vertices = np.vstack((x,y,z)).T
        tri_points = vertices[triangles]

        Xe = []
        Ye = []
        Ze = []
        # Trace each triangle as v0 -> v1 -> v2 -> v0; None breaks the line between triangles
        for T in tri_points:
            Xe.extend([T[k%3][0] for k in range(4)]+[ None])
            Ye.extend([T[k%3][1] for k in range(4)]+[ None])
            Ze.extend([T[k%3][2] for k in range(4)]+[ None])

        lines = go.Scatter3d(
                x=Xe,
                y=Ye,
                z=Ze,
                mode='lines',
                name='',
                line=dict(color='rgb(70,70,70)', width=1)
            )

    data = [mesh, lines] if wireframe else [mesh]

    if mode == "gif":
        return data
    
    camera = dict(
        up=dict(x=0, y=0, z=1),
        center=dict(x=0, y=0, z=0),
        eye=dict(x=cam_x, y=cam_y, z=0)
    )

    layout = go.Layout(
            font=dict(size=16, color='white'),
            width=width,
            height=height,
            scene_xaxis_visible=False,
            scene_yaxis_visible=False,
            scene_zaxis_visible=False,
            paper_bgcolor='rgb(50,50,50)',
            scene_camera = camera
            )

    fig = go.Figure(data=data, layout = layout)

    if mode == "image":
        bunny_img = fig.to_image(format="png", engine="kaleido")
        return bunny_img
    elif mode == "plot":
        fig.show()

In [8]:
def normalise(mat):
    """Min-max normalise a tensor to [0, 1] using its global min and max."""

    mat_max = tf.math.reduce_max(mat)
    mat_min = tf.math.reduce_min(mat)
    n_mat = tf.math.truediv((mat - mat_min), (mat_max - mat_min))

    return n_mat

## Load the bunny mesh

In [9]:
bunny = trimesh.load("bunny4k.obj")

In [None]:
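# Optional: show a static PNG instead of the interactive plot (requires kaleido)
# Image(plot_mesh(bunny, mode="image"))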

In [11]:
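# Shift and scale all coordinates with one global min/max so they lie in [0, 10]; the shape is preserved
bunny.vertices = normalise(bunny.vertices)*10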
assert np.asarray(np.max(bunny.vertices)) == 10.0

In [12]:
plot_mesh(bunny, 400, 400)

## Create a smoothed bunny mesh

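A second copy of the bunny is loaded, normalised in the same way, and smoothed with trimesh's `trimesh.smoothing.filter_laplacian` (10 iterations with `lamb=0.5` and a volume constraint). The asserts check that smoothing moves the vertices without changing the mesh connectivity.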
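
The idea behind Laplacian smoothing can be sketched in a few lines of NumPy: every vertex is moved a fraction `lamb` of the way towards the centroid of its neighbours, and this is repeated. This is a simplified sketch assuming uniform ("umbrella") weights; trimesh's `filter_laplacian` also supports other Laplacian operators, implicit time integration and the volume constraint used here, none of which is reproduced. The function `laplacian_smooth` is hypothetical.

```python
import numpy as np


def laplacian_smooth(vertices, neighbors, lamb=0.5, iterations=10):
    """Hypothetical sketch: explicit Laplacian smoothing with uniform weights.

    vertices:  (n, 3) array of vertex positions
    neighbors: one list of neighbour indices per vertex
               (trimesh exposes this as mesh.vertex_neighbors)
    """
    v = np.asarray(vertices, dtype=float).copy()
    for _ in range(iterations):
        # Centroid of each vertex's one-ring neighbours (computed before any move)
        centroids = np.array([v[nbrs].mean(axis=0) for nbrs in neighbors])
        # Move every vertex a fraction `lamb` of the way towards that centroid
        v += lamb * (centroids - v)
    return v


# Toy input: a 4-cycle 0-1-2-3-0 with vertex 3 lifted out of the plane as a spike
square = np.array([[0., 0., 0.], [1., 0., 0.], [1., 1., 0.], [0., 1., 1.]])
ring = [[1, 3], [0, 2], [1, 3], [0, 2]]
smoothed = laplacian_smooth(square, ring, lamb=0.5, iterations=10)
print(smoothed.round(3))
```

On this toy input the spike is flattened and the vertices collapse towards their common centroid, which illustrates why plain Laplacian smoothing shrinks a mesh; the `volume_constraint=True` flag passed to `filter_laplacian` in this notebook is meant to keep the mesh volume constant.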

In [13]:
bunny_s = trimesh.load("bunny4k.obj")
assert bunny_s.vertex_neighbors[0] == bunny.vertex_neighbors[0]

In [14]:
bunny_s.vertices = normalise(bunny_s.vertices)*10
assert np.asarray(np.max(bunny_s.vertices)) == 10.0

In [15]:
trimesh.smoothing.filter_laplacian(bunny_s, lamb=0.5, iterations=10, implicit_time_integration=False, volume_constraint=True, laplacian_operator=None)

<trimesh.Trimesh(vertices.shape=(2021, 3), faces.shape=(3999, 3))>

In [16]:
assert bunny_s.vertex_neighbors[0] == bunny.vertex_neighbors[0]

In [17]:
plot_mesh(bunny_s, 400, 400)

## Compute the normalised ground-truth HKS (i.e. the HKS of the original mesh)

In [19]:
t = 0.5

In [36]:
hks_bunny = hks.hks_t(bunny.vertices, t)
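
`hks.hks_t` is defined in `heat_kernel_signature.py` in the cloned repository and is not reproduced in this notebook. To illustrate what such a function computes, here is a minimal NumPy sketch of the standard eigen-expansion. It assumes a *combinatorial* graph Laplacian built from the mesh edges, in place of the cotangent Laplace–Beltrami operator a mesh implementation would normally use; the name `hks_graph` is hypothetical.

```python
import numpy as np


def hks_graph(num_vertices, edges, t):
    """Hypothetical sketch: HKS from the combinatorial graph Laplacian L = D - A.

    edges: iterable of undirected (i, j) vertex-index pairs, each listed once.
    Returns h_t(x) = sum_i exp(-lambda_i * t) * phi_i(x)**2 for every vertex x.
    """
    A = np.zeros((num_vertices, num_vertices))
    for i, j in edges:
        A[i, j] = A[j, i] = 1.0
    L = np.diag(A.sum(axis=1)) - A
    # L is symmetric, so eigh gives real eigenvalues and orthonormal eigenvectors
    lam, phi = np.linalg.eigh(L)
    # phi[:, i] is the i-th eigenvector; weight phi_i(x)^2 by exp(-lambda_i t)
    return (phi ** 2) @ np.exp(-lam * t)


# Toy input: a path 0-1-2-3 (two endpoints, two interior vertices)
path_edges = [(0, 1), (1, 2), (2, 3)]
h = hks_graph(4, path_edges, t=0.5)
print(h.round(3))
```

The two endpoints receive equal signatures, as do the two interior vertices, and the endpoints score higher: heat leaves a vertex with fewer neighbours more slowly. At `t=0` every vertex gets exactly 1, because the rows of the orthonormal eigenvector matrix have unit length.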

In [38]:
# Normalisation constants: the HKS of the smoothed mesh is rescaled with these ground-truth extremes at every iteration.
high = tf.math.reduce_max(hks_bunny)
low = tf.math.reduce_min(hks_bunny)
print(np.asarray(high), np.asarray(low))

0.6178457747168905 0.6108167566172786
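
One plausible reason for reusing `high` and `low` instead of calling `normalise` on the smoothed mesh's HKS at every iteration: min-max normalising each signal with its own extremes discards any offset and scale, so two signatures that differ only by a uniform shift would give zero error. The toy arrays below use illustrative values (not taken from the bunny) in roughly the range printed above; `minmax`, `gt_low` and `gt_high` are hypothetical names chosen so the notebook's `low` and `high` are not overwritten.

```python
import numpy as np


def minmax(a):
    """Min-max normalise with the array's own extremes, like `normalise` above."""
    return (a - a.min()) / (a.max() - a.min())


# Illustrative values only
target = np.array([0.611, 0.614, 0.618])   # stands in for the ground-truth HKS
current = target + 0.002                    # same pattern, uniformly shifted

gt_low, gt_high = target.min(), target.max()

# Each signal normalised with its own min/max: the shift disappears
err_self = np.sum((minmax(current) - minmax(target)) ** 2)
# Both normalised with the ground-truth constants: the shift is penalised
err_fixed = np.sum(((current - gt_low) / (gt_high - gt_low)
                    - (target - gt_low) / (gt_high - gt_low)) ** 2)
print(err_self, err_fixed)
```

With fixed constants the normalised HKS of the smoothed mesh may fall outside [0, 1], and that deviation contributes to the error.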


In [39]:
n_hks_bunny = normalise(hks_bunny)

In [40]:
bunny.visual.vertex_colors = trimesh.visual.color.interpolate(n_hks_bunny, "jet")
print("t = ", t)
bunny.show(height = 200)

t =  0.5


## Optimisation

In [24]:
def smooth_bunny():
    """Reload, normalise and smooth the bunny exactly as in the cells above."""
    bunny_s = trimesh.load("bunny4k.obj")
    assert bunny_s.vertex_neighbors[0] == bunny.vertex_neighbors[0]
    bunny_s.vertices = normalise(bunny_s.vertices)*10
    trimesh.smoothing.filter_laplacian(bunny_s, lamb=0.5, iterations=10, implicit_time_integration=False, volume_constraint=True, laplacian_operator=None)
    return bunny_s

In [41]:
bunny_s = smooth_bunny()

In [42]:
plot_mesh(bunny_s, 400, 400)

Define the hyperparameters:

1.   `t`: time scale of the HKS
2.   `l`: weight of the L2 regularisation term added to the gradient
3.   `alpha`: learning rate (gradient-descent step size)
4.   `num_iters`: number of gradient-descent iterations
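
Written out, the optimisation loop below minimises the squared difference between the normalised HKS of the mesh being optimised (vertex positions $V$) and that of the original bunny ($V^\ast$). Both are rescaled with the ground-truth extremes $h^{\min}$ = `low` and $h^{\max}$ = `high`. The term $l\,V$ added to the gradient acts as weight decay, i.e. it corresponds to an extra penalty $\tfrac{l}{2}\lVert V \rVert^2$ in the objective:

$$
\hat h_t(x; V) = \frac{h_t(x; V) - h^{\min}}{h^{\max} - h^{\min}}, \qquad
E(V) = \sum_{x} \bigl(\hat h_t(x; V) - \hat h_t(x; V^\ast)\bigr)^2
$$

$$
V \leftarrow V - \alpha \bigl(\nabla_V E(V) + l\, V\bigr)
$$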



In [43]:
t = 0.5
l = 0.01
alpha = 0.1
num_iters = 25

The Plotly traces of the mesh at every iteration (starting with the smoothed mesh) are kept in the `pltdata` list so the optimisation can be animated later.

In [44]:
ogdata = plot_mesh(bunny_s, mode="gif")
pltdata = [ogdata]

In [45]:
for i in range(num_iters):
    print("iter: ", i+1)
    with tf.GradientTape() as g:
        # Current vertex positions are the variables being optimised
        x = tf.constant(bunny_s.vertices)
        g.watch(x)
        # HKS of the current mesh, rescaled with the ground-truth constants
        hks_bunny_s = hks.hks_t(x, t)
        n_hks_bunny_s = tf.math.truediv((hks_bunny_s - low), (high - low))
        # Sum of squared differences to the normalised ground-truth HKS
        err = tf.math.reduce_sum(tf.math.square(n_hks_bunny_s - n_hks_bunny))
    print("error: ", np.asarray(err))

    # Gradient of the HKS error plus an L2 (weight-decay) term l*x
    grad = g.gradient(err, x) + l*x
    # Gradient-descent step on the vertex positions
    bunny_s.vertices -= alpha*grad.numpy()

    # Keep the Plotly traces of the updated mesh for the animation
    pltdata.append(plot_mesh(bunny_s, mode="gif"))

iter:  1
error:  19.453252574607596
iter:  2
error:  5.514294695628189
iter:  3
error:  2.505911850392838
iter:  4
error:  1.5367448310199177
iter:  5
error:  1.0910595091046076
iter:  6
error:  0.8390725872253312
iter:  7
error:  0.6775358130556612
iter:  8
error:  0.5683004305554854
iter:  9
error:  0.49116631717866593
iter:  10
error:  0.43503735737641674
iter:  11
error:  0.39300251654422286
iter:  12
error:  0.3608058209339705
iter:  13
error:  0.3356391500357199
iter:  14
error:  0.3156298663581716
iter:  15
error:  0.2994739259420894
iter:  16
error:  0.28624966979753974
iter:  17
error:  0.2752888432836406
iter:  18
error:  0.26610028215319437
iter:  19
error:  0.25831671043620735
iter:  20
error:  0.25165989644225556
iter:  21
error:  0.2459159861681103
iter:  22
error:  0.24091837216724765
iter:  23
error:  0.23653534104571328
iter:  24
error:  0.23266117193716018
iter:  25
error:  0.22920958959718618


## Visualisation

Parameters for visualisation

In [46]:
cam_x = 1.25
cam_y = -1.25
width = 500
height = 500

# Angle (in radians) by which the camera is rotated about the vertical axis;
# changing it changes the initial viewing direction of the mesh
angle = 0

In [47]:
rot = np.array([[np.cos(angle), np.sin(angle)], [-np.sin(angle), np.cos(angle)]])
cam = np.array([cam_x, cam_y])

cam = np.matmul(rot, cam)

camera = dict(
        up=dict(x=0, y=0, z=1),
        center=dict(x=0, y=0, z=0),
        eye=dict(x=cam[0], y=cam[1], z=0)
    )

layouts = []
iters = len(pltdata)
for i in range(iters):

    layout = go.Layout(
            font=dict(size=16, color='grey'),
            width=width,
            height=height,
            scene_xaxis_visible=False,
            scene_yaxis_visible=False,
            scene_zaxis_visible=False,
            paper_bgcolor='rgb(50,50,50)',
            scene_camera = camera,
            title=str(i),
            updatemenus=[dict(
                type="buttons",
                buttons=[dict(label="Play",
                        method="animate",
                        args=[None])]
                )] if i == 0 else None,
        )
    layouts.append(layout)

fig = go.Figure(
    data=pltdata[0],
    layout=layouts[0],
    frames=[go.Frame(data=pltdata[j], layout = layouts[j])
        for j in range(1, iters)]
)

fig.show()