## Fitting a spherical mesh to an object's mesh using PyTorch3D ##

In this notebook we begin our journey with PyTorch3D. As a starter, we deform a spherically initialized mesh so that it fits a target mesh.

**Disclaimer**: I am a noob in PyTorch3D, so you will see quite a bit of struggle along the way. But, as always, I will try my best to write good code.

*(This notebook and the following PyTorch3D notebooks are heavily influenced by the official PyTorch3D tutorials: https://pytorch3d.org/tutorials/ )*

Installing PyTorch3D on Windows is a pain, so I will be using Kaggle Kernels instead.

To set up PyTorch3D on Kaggle we need to install the necessary packages. We take inspiration from [Kaggle_Pytorch3d](https://www.kaggle.com/code/aynur19/pytorch3d-3d-model-rendering/notebook?scriptVersionId=79722958) and go ahead with our work.

```python
## Code adapted from https://www.kaggle.com/code/aynur19/pytorch3d-3d-model-rendering/notebook?scriptVersionId=79722958 ##

import os
from packaging import version

need_pytorch3d = False
pytorch3dVersion = '0.6.0'

## Check whether a recent enough PyTorch3D is already installed ##
try:
    import pytorch3d as p3d
    if version.parse(p3d.__version__) < version.parse(pytorch3dVersion):
        need_pytorch3d = True
except ModuleNotFoundError:
    need_pytorch3d = True

## Build PyTorch3D from source only when needed (CUB is required for the CUDA build) ##
if need_pytorch3d:
    !curl -LO https://github.com/NVIDIA/cub/archive/1.10.0.tar.gz
    !tar xzf 1.10.0.tar.gz
    os.environ["CUB_HOME"] = os.getcwd() + "/cub-1.10.0"
    !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'
```

With PyTorch3D installed, we can carry on with the project by importing the necessary packages.

```python
!pip install wget
```

```python
## Importing the necessary packages ##

import wget
import torch
from pytorch3d.io import load_obj , save_obj
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere
from pytorch3d.loss import chamfer_distance , mesh_edge_loss , mesh_laplacian_smoothing , mesh_normal_consistency
from pytorch3d.ops import sample_points_from_meshes

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib
```

With that we are done importing the necessary packages.

Now, in this starter project we are going to fit a spherical mesh to a target mesh, so we need a target mesh. Many websites offer meshes, but the most convenient one I found is this: https://people.sc.fsu.edu/~jburkardt/data/obj/obj.html . We can download one of its meshes and use it as our target mesh.

To download it we use the Python **wget** package, so make sure it is installed with <code>pip install wget</code>.

For this work I chose the violin case .obj from that website, which the next cell downloads.

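```python
## Downloading the target .obj file (skipped if it already exists) ##

import os

link = 'https://people.sc.fsu.edu/~jburkardt/data/obj/violin_case.obj'
filename = os.path.basename(link)

if not os.path.exists(filename):
    filename = wget.download(link)
    print('\n' + filename , 'is downloaded!')
else:
    print(filename , 'already exists, skipping the download.')
```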

If you have already downloaded the file, don't download it again: wget does not overwrite it but saves another copy under a new name each time, which slowly fills up the disk. So run the cell above only once, and afterwards just set the filename by hand, as in the next cell.

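Now we load the .obj file with PyTorch3D. The <code>load_obj</code> function in <code>pytorch3d.io</code> returns three things: the vertices (a (V, 3) float tensor), the faces (a named tuple whose <code>verts_idx</code> field holds the (F, 3) vertex indices of each face) and <code>aux</code> (auxiliary data such as normals, texture coordinates and materials).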

```python
## Setting the filename ##
## Important if you have downloaded something before ##

filename = 'violin_case.obj'

## Loading the .obj file ##

vertices , faces , aux = load_obj(filename)

print('Target .obj is loaded!')
```
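To make the vertices and faces more concrete, here is a minimal, hypothetical OBJ reader in plain Python. It is not PyTorch3D's <code>load_obj</code>; it only shows where a vertex list (from `v` lines) and 0-based triangle indices (from `f` lines, which are 1-based in the file) come from. Fan-triangulating polygons is our own simplifying choice, and normals, texture coordinates and materials (which PyTorch3D puts in <code>aux</code>) are ignored.

```python
# A minimal, hypothetical OBJ reader: NOT PyTorch3D's load_obj, only an
# illustration of where "vertices" and "faces.verts_idx" come from.
# It ignores normals, texture coordinates, materials and negative indices.

def parse_obj(text):
    """Return (vertices, faces) parsed from OBJ text.

    vertices: list of [x, y, z] floats, one per 'v' line.
    faces: list of [i, j, k] 0-based vertex indices, one per triangle.
    """
    vertices, faces = [], []
    for line in text.splitlines():
        parts = line.split()
        if not parts:
            continue
        if parts[0] == 'v':
            vertices.append([float(c) for c in parts[1:4]])
        elif parts[0] == 'f':
            # Entries look like "7", "7/2" or "7/2/5"; the first number is the
            # 1-based vertex index, so subtract 1 to make it 0-based.
            idx = [int(p.split('/')[0]) - 1 for p in parts[1:]]
            # Fan-triangulate polygons with more than three corners.
            for a in range(1, len(idx) - 1):
                faces.append([idx[0], idx[a], idx[a + 1]])
    return vertices, faces


# Usage: a unit square written as one quad becomes two triangles.
square = """
v 0 0 0
v 1 0 0
v 1 1 0
v 0 1 0
f 1 2 3 4
"""
verts, tris = parse_obj(square)
print(len(verts), tris)  # 4 [[0, 1, 2], [0, 2, 3]]
```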

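With the file loaded, we build the target mesh from its vertices and faces. We also normalize the vertices, centering them at the origin and scaling them into [-1, 1], so that the target has roughly the same position and size as the unit sphere we start from; this makes the fit converge faster.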
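Written out, the normalization maps every target vertex v to v' = (v - mean) / s, where, following the official PyTorch3D deformation tutorial, s is the largest absolute coordinate after centering. The parentheses matter: v - mean / s is a different operation. At the end of training the inverse v = v' * s + mean brings the fitted vertices back to the original frame. A minimal NumPy sketch of this round trip, with hypothetical function names and made-up vertex data:

```python
import numpy as np

def normalize(verts):
    """Center verts at the origin and scale them into [-1, 1].

    Returns the normalized vertices plus the (center, scale) needed to undo it.
    """
    center = verts.mean(axis=0)
    centered = verts - center
    scale = np.abs(centered).max()  # largest absolute coordinate after centering
    return centered / scale, center, scale

def denormalize(verts, center, scale):
    """Inverse of normalize: map vertices back to the original frame."""
    return verts * scale + center

rng = np.random.default_rng(0)
original = rng.uniform(5.0, 50.0, size=(100, 3))  # made-up vertices far from the origin
normed, center, scale = normalize(original)
print(np.abs(normed).max())                                      # 1.0
print(np.allclose(normed.mean(axis=0), 0))                       # True
print(np.allclose(denormalize(normed, center, scale), original)) # True
```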

We also create the initial mesh, a sphere, which we will deform into the target.

After doing that we will also visualize the meshes, using a custom visualization function.

```python
## Setting the default device ##

device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')

## Normalizing the vertices: center at the origin, then scale into [-1, 1] ##

## Finding the mean and centering ##
vertices_mean = vertices.mean(0)
vertices = vertices - vertices_mean

## Finding the max absolute coordinate of the centered vertices ##
vertices_max_val = max(vertices.abs().max(0)[0])

vertices = vertices / vertices_max_val

## Shifting the vertices and the corresponding face indices to the device ##
vertices = vertices.to(device)
faces_verts_idx = faces.verts_idx.to(device)

## Creating target mesh ##
target_mesh = Meshes(verts = [vertices] , faces = [faces_verts_idx])

## Creating initial mesh: a level-4 icosphere ##
init_mesh = ico_sphere(4 , device = device)
```
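How big is the starting sphere? Assuming <code>ico_sphere(level)</code> starts from a regular icosahedron and splits every triangle into four at each level (our understanding of how PyTorch3D builds it), the vertex, edge and face counts follow a simple recurrence. The helper below is hypothetical and only does that arithmetic; the Euler characteristic V - E + F = 2 serves as a sanity check.

```python
def icosphere_counts(level):
    """Vertex, edge and face counts of an icosphere after `level` subdivisions.

    Level 0 is the icosahedron (12 vertices, 30 edges, 20 faces). Each
    subdivision splits every triangle into four: one new vertex per edge,
    every edge is split in two and each face adds three inner edges.
    """
    v, e, f = 12, 30, 20
    for _ in range(level):
        v, e, f = v + e, 2 * e + 3 * f, 4 * f
    return v, e, f

for level in range(5):
    print(level, icosphere_counts(level))
# last line printed: 4 (2562, 7680, 5120)
```

By this count, level 4 gives 2562 vertices, so <code>init_mesh.verts_packed()</code> should have shape (2562, 3), which is also the shape of the offset tensor we optimize later.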

```python
## Visualizing meshes ##
## Code inspired from : https://pytorch3d.org/tutorials/deform_source_mesh_to_target_mesh ##

def mesh_visualization(mesh , title = ''):
    '''
    Given a mesh, creates a 3D scatter plot of a fixed
    number of points sampled from its surface.
    '''
    
    ## Setting the number of sampling points ## 
    num_points = 15000
    
    ## Samples points on the mesh surface, shape (1, num_points, 3) ##
    mesh_points = sample_points_from_meshes(mesh , num_samples = num_points)
    
    ## Grabs the x,y,z coordinates ##
    x, y, z = mesh_points.clone().detach().cpu().squeeze().unbind(1) 
    
    ## Setting up the figure ##
    fig = plt.figure(figsize=(5, 5))
    
    ## Setting the 3d plot (creating Axes3D(fig) directly is deprecated in recent matplotlib) ##
    ax = fig.add_subplot(111, projection='3d')
    
    ## Creating the scatter plot ##
    ax.scatter3D(x, y, z)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')
    ax.set_title(title)
    
    plt.show()
```
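Both the visualization and the training loss rely on <code>sample_points_from_meshes</code>, which the PyTorch3D documentation describes as sampling points uniformly on the mesh surface, picking faces with probability proportional to their area. The NumPy function below is our own hypothetical sketch of that idea, not PyTorch3D's code: it picks faces by area, then uses the standard square-root trick to draw uniform barycentric coordinates inside each picked triangle.

```python
import numpy as np

def sample_surface_points(verts, faces, num_samples, rng=None):
    """Sample points uniformly on a triangle mesh surface (hypothetical helper).

    verts: (V, 3) float array, faces: (F, 3) int array of vertex indices.
    Faces are picked with probability proportional to their area, then a
    point is drawn uniformly inside each picked triangle.
    """
    rng = np.random.default_rng() if rng is None else rng
    v0, v1, v2 = (verts[faces[:, i]] for i in range(3))
    areas = 0.5 * np.linalg.norm(np.cross(v1 - v0, v2 - v0), axis=1)
    face_idx = rng.choice(len(faces), size=num_samples, p=areas / areas.sum())

    # Uniform barycentric coordinates via the square-root trick.
    r1 = np.sqrt(rng.random(num_samples))[:, None]
    r2 = rng.random(num_samples)[:, None]
    a, b, c = v0[face_idx], v1[face_idx], v2[face_idx]
    return (1 - r1) * a + r1 * (1 - r2) * b + r1 * r2 * c

# Usage: a 2 x 1 rectangle in the z = 0 plane, split into two triangles.
verts = np.array([[0, 0, 0], [2, 0, 0], [2, 1, 0], [0, 1, 0]], dtype=float)
faces = np.array([[0, 1, 2], [0, 2, 3]])
pts = sample_surface_points(verts, faces, 15000, np.random.default_rng(0))
x, y, z = pts.T  # same unpacking idea as unbind(1) in mesh_visualization
```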

Now let's visualize!!

```python
## Plotting the target and initial meshes ##
mesh_visualization(target_mesh , 'Target Mesh')
mesh_visualization(init_mesh , 'Init Mesh')
```

Amazing!!
We can now visualize the meshes. 

Next, we need to set up the optimizer.

In a usual PyTorch workflow we optimize the parameters of a model; here we want to optimize a mesh instead. A <code>Meshes</code> object cannot be handed to an optimizer directly, so we optimize a tensor of per-vertex offsets and apply it with <code>.offset_verts</code>, which returns a new mesh whose vertices are the original vertices plus the offsets (the faces stay the same). The gradients of the loss with respect to this offset tensor tell the optimizer how to move each vertex. The offsets start at zero, so the first iteration works on the unmodified sphere.

So what do we need to do next?
* Create an offset tensor, with the same shape as the sphere's vertex tensor, that holds the per-vertex changes passed to <code>offset_verts</code>.
* Create an optimizer whose only parameter is that offset tensor.
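The idea of optimizing offsets instead of the mesh itself can be seen on a toy problem. This NumPy sketch is only an analogy: the data is made up, each vertex has a known target position instead of a Chamfer loss, and plain gradient descent with a hand-written gradient stands in for Adam and autograd. The offsets start at zero, and each step moves every vertex along the negative gradient, just as <code>optim.step()</code> updates the offset tensor in the training loop.

```python
import numpy as np

# Hypothetical toy problem: move 4 "sphere" vertices onto 4 known target
# positions by optimizing only an offset array, mirroring offset_verts.
base_verts = np.array([[1, 0, 0], [0, 1, 0], [-1, 0, 0], [0, -1, 0]], dtype=float)
targets = base_verts * 2.0 + np.array([0.5, 0.0, 0.0])

offsets = np.zeros_like(base_verts)  # the first iteration sees the unmodified mesh
lr = 0.1
for step in range(200):
    new_verts = base_verts + offsets           # what offset_verts would produce
    residual = new_verts - targets
    loss = (residual ** 2).sum(axis=1).mean()  # mean squared distance per vertex
    grad = 2.0 * residual / len(base_verts)    # d loss / d offsets
    offsets -= lr * grad                       # plain gradient descent step

print(np.abs(base_verts + offsets - targets).max())  # tiny: vertices reached targets
```

The base vertices never change; only the offsets are learned, which is exactly why the notebook hands the offset tensor, not the mesh, to the optimizer.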

```python
## Creating the offset tensor: zeros with the same shape as the sphere's vertices ##

vertices_value_changer_tensor = torch.full(init_mesh.verts_packed().shape , 0.0 , device = device , requires_grad = True)

## Setting the optimizer (Adam with its default learning rate of 1e-3) ##

optim = torch.optim.Adam([vertices_value_changer_tensor])
```

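Now we must set up our loss function. The main term is the Chamfer distance between point clouds sampled from the deformed sphere and from the target: for every point, it measures how far away the nearest point of the other cloud is. The Chamfer distance only looks at sampled points, so to get a better and smoother mesh we add regularizers using the <code>mesh_edge_loss , mesh_laplacian_smoothing , mesh_normal_consistency</code> functions: the edge loss penalizes long edges, the Laplacian smoothing term pulls each vertex toward the average of its neighbours, and the normal consistency term penalizes differences between the normals of neighbouring faces.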
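A brute-force sketch of the Chamfer term in NumPy. As far as we know, PyTorch3D's <code>chamfer_distance</code> by default uses squared distances, averages them over the points of each cloud and sums the two directions; the function below follows that convention, but it is our own illustration. It builds the full N x M distance matrix, so it is only practical for small clouds (PyTorch3D uses a nearest-neighbour search instead).

```python
import numpy as np

def chamfer(x, y):
    """Symmetric Chamfer distance between point clouds x (N, 3) and y (M, 3).

    For every point, take the squared distance to its nearest neighbour in the
    other cloud, average within each direction, then add the two directions.
    """
    d2 = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)  # (N, M) squared distances
    return d2.min(axis=1).mean() + d2.min(axis=0).mean()

a = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
b = np.array([[0.0, 0.0, 0.0]])
print(chamfer(a, a))  # 0.0
print(chamfer(a, b))  # 0.5: a->b averages 0 and 1, b->a is 0
```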

```python
## Loss function definition ##

def calculate_loss(initial_pointcloud , target_pointcloud , initial_mesh):
    '''
    Calculates the Chamfer distance between the point clouds sampled from
    the two meshes, and adds the regularizers computed on the deformed mesh.
    '''
    chamfer_loss , _ = chamfer_distance(initial_pointcloud , target_pointcloud)
    
    edge_reg = mesh_edge_loss(initial_mesh)
    
    laplacian_reg = mesh_laplacian_smoothing(initial_mesh)
    
    normal_reg = mesh_normal_consistency(initial_mesh)
    
    total_loss = chamfer_loss + edge_reg + 0.1 * laplacian_reg + 0.01 * normal_reg
    
    return total_loss
```
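To get a feel for two of the regularizers, here are hypothetical NumPy versions of an edge-length loss (squared deviation of each edge length from a target length, 0 by default as in <code>mesh_edge_loss</code>) and a uniform Laplacian loss (distance between each vertex and the mean of its neighbours, the idea behind the default "uniform" method of <code>mesh_laplacian_smoothing</code>). They are sketches for a single mesh only; PyTorch3D's versions also handle batches and weighting, and its Laplacian offers cotangent-based variants. The sketch assumes every vertex belongs to at least one face.

```python
import numpy as np

def unique_edges(faces):
    """Undirected edges of a triangle mesh as a sorted (E, 2) index array."""
    e = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]])
    return np.unique(np.sort(e, axis=1), axis=0)

def edge_loss(verts, faces, target_length=0.0):
    """Mean squared deviation of edge lengths from target_length."""
    e = unique_edges(faces)
    lengths = np.linalg.norm(verts[e[:, 0]] - verts[e[:, 1]], axis=1)
    return ((lengths - target_length) ** 2).mean()

def uniform_laplacian_loss(verts, faces):
    """Mean distance between each vertex and the average of its neighbours.

    Assumes every vertex belongs to at least one face (degree > 0).
    """
    e = unique_edges(faces)
    neighbour_sum = np.zeros_like(verts)
    degree = np.zeros(len(verts))
    for i, j in e:  # every edge makes i and j neighbours of each other
        neighbour_sum[i] += verts[j]
        neighbour_sum[j] += verts[i]
        degree[i] += 1
        degree[j] += 1
    delta = neighbour_sum / degree[:, None] - verts
    return np.linalg.norm(delta, axis=1).mean()

tri_verts = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0]], dtype=float)
tri_faces = np.array([[0, 1, 2]])
print(edge_loss(tri_verts, tri_faces))  # about 4/3: squared lengths 1, 1 and 2
print(uniform_laplacian_loss(tri_verts, tri_faces))
```

Long edges and vertices that stick out from their neighbourhood both raise these values, which is why adding them to the Chamfer term keeps the deformed sphere smooth.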

Now that everything is set up, let's write the training loop.

```python
## Setting up the training loop ##

def fit(num_iter):
    '''
    Fits the initial (spherical) mesh to the target mesh by optimizing the
    per-vertex offsets for num_iter iterations, then returns the fitted
    vertices and faces in the target's original coordinate frame.
    '''
    
    loop = tqdm(range(num_iter))
    
    for each_iter in loop:
        
        optim.zero_grad()
        
        ## Deform the sphere by the current offsets ##
        updated_init_mesh = init_mesh.offset_verts(vertices_value_changer_tensor)
        
        ## Sample point clouds from both meshes ##
        initial_pointcloud = sample_points_from_meshes(updated_init_mesh , num_samples = 15000)
        target_pointcloud = sample_points_from_meshes(target_mesh , num_samples = 15000)
        
        loss = calculate_loss(initial_pointcloud , target_pointcloud , updated_init_mesh)
        
        loop.set_description('Epoch : {} / {}'.format(each_iter + 1 , num_iter))
        loop.set_postfix(Loss = loss.item())
        
        ## Visualize the progress every 500 iterations ##
        if (each_iter + 1) % 500 == 0:
            mesh_visualization(target_mesh , 'Epoch : {} / {} --> Target Mesh'.format(each_iter + 1 , num_iter))
            mesh_visualization(updated_init_mesh , 'Epoch : {} / {} --> Init Mesh'.format(each_iter + 1 , num_iter))
        
        loss.backward()
        optim.step()
    
    ## Rebuild the mesh with the offsets from the last optimizer step ##
    with torch.no_grad():
        final_mesh = init_mesh.offset_verts(vertices_value_changer_tensor)
    final_vertices, final_faces = final_mesh.get_mesh_verts_faces(0)
    
    ## Undo the normalization: v = v' * max_val + mean ##
    final_vertices = final_vertices * vertices_max_val.to(device) + vertices_mean.to(device)
    
    return final_vertices.to('cpu') , final_faces.to('cpu')
```

```python
## Running the script ##

final_vertices , final_faces = fit(3000)
```

Amazing. Our mesh is very nicely fitted.

So, let's save our final .obj.

```python
## Saving our .obj ##

save_obj('./fitted_mesh_final.obj' , final_vertices , final_faces)

print('Obj Saved Successfully!!')
```

```python
## Downloading the obj file ##

import os
os.chdir(r'/kaggle/working')

!tar -czf obj.tar.gz fitted_mesh_final.obj

from IPython.display import FileLink

FileLink(r'obj.tar.gz')
```