<a href="https://colab.research.google.com/github/Aydin-ab/CV_DMTet/blob/main/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Installing Dependencies: Kaolin and PyTorch3D

In [1]:
from google.colab import drive, auth
drive.mount('/content/drive')
from pathlib import Path
INSTALL_PATH = Path("/content/drive/MyDrive/CV_DMTet/")
%cd  $INSTALL_PATH
%env XDG_CACHE_HOME=/content/drive/MyDrive/CV_DMTet

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/CV_DMTet
env: XDG_CACHE_HOME=/content/drive/MyDrive/CV_DMTet


In [2]:
# reinstall cython, install usd-core (for 3D rendering), and clone into kaolin repo
!pip uninstall Cython --yes
import torch
!pip install  Cython==0.29.20  --quiet
!pip install  usd-core --quiet


Found existing installation: Cython 0.29.20
Uninstalling Cython-0.29.20:
  Successfully uninstalled Cython-0.29.20


In [3]:
# installing kaolin and check version
%env IGNORE_TORCH_VER=1
%env KAOLIN_INSTALL_EXPERIMENTAL=1
KAOLIN_PATH = INSTALL_PATH / "kaolin"
%cd $INSTALL_PATH
!if [ ! -d $KAOLIN_PATH ]; then git clone --recursive https://github.com/NVIDIAGameWorks/kaolin; fi;
%cd $KAOLIN_PATH
SETUP_CHECK = KAOLIN_PATH / "kaolin" / "version.py"
!echo Checking if $SETUP_CHECK exists
!if [ ! -f $SETUP_CHECK ]; then python setup.py develop; fi;
# !python -c "import kaolin; print(kaolin.__version__)"

env: IGNORE_TORCH_VER=1
env: KAOLIN_INSTALL_EXPERIMENTAL=1
/content/drive/MyDrive/CV_DMTet
/content/drive/MyDrive/CV_DMTet/kaolin
Checking if /content/drive/MyDrive/CV_DMTet/kaolin/kaolin/version.py exists


In [4]:
# !python setup.py install_lib install_scripts build
!python setup.py develop

Compiling kaolin/cython/ops/mesh/triangle_hash.pyx because it depends on /usr/local/lib/python3.8/dist-packages/Cython/Includes/libcpp/vector.pxd.
Compiling kaolin/cython/ops/conversions/mise.pyx because it depends on /usr/local/lib/python3.8/dist-packages/Cython/Includes/libcpp/vector.pxd.
[1/2] Cythonizing kaolin/cython/ops/conversions/mise.pyx
  tree = Parsing.p_module(s, pxd, full_module_name)
[2/2] Cythonizing kaolin/cython/ops/mesh/triangle_hash.pyx
  tree = Parsing.p_module(s, pxd, full_module_name)
running develop
running egg_info
writing kaolin.egg-info/PKG-INFO
writing dependency_links to kaolin.egg-info/dependency_links.txt
writing requirements to kaolin.egg-info/requires.txt
writing top-level names to kaolin.egg-info/top_level.txt
reading manifest template 'MANIFEST.in'
adding license file 'LICENSE'
writing manifest file 'kaolin.egg-info/SOURCES.txt'
running build_ext
building 'kaolin.ops.mesh.triangle_hash' extension
x86_64-linux-gnu-gcc -pthread -Wno-unused-result -Wsign-

# Import packages

In [5]:
import numpy as np
import torch
import kaolin
import sys
import os

need_pytorch3d=False
try:
    import pytorch3d
except ModuleNotFoundError:
    need_pytorch3d=True
if need_pytorch3d:
    if torch.__version__.startswith("1.12.") and sys.platform.startswith("linux"):
        # We try to install PyTorch3D via a released wheel.
        pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
        version_str="".join([
            f"py3{sys.version_info.minor}_cu",
            torch.version.cuda.replace(".",""),
            f"_pyt{pyt_version_str}"
        ])
        print(f"version_str : {version_str}")
        !pip install fvcore iopath
        !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
    else:
        # We try to install PyTorch3D from source.
        !curl -LO https://github.com/NVIDIA/cub/archive/1.10.0.tar.gz
        !tar xzf 1.10.0.tar.gz
        os.environ["CUB_HOME"] = os.getcwd() + "/cub-1.10.0"
        !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'

from kaolin.ops.conversions import (
    trianglemeshes_to_voxelgrids,
    marching_tetrahedra,
    voxelgrids_to_cubic_meshes,
    voxelgrids_to_trianglemeshes,
)

from kaolin.ops.mesh import (
    index_vertices_by_faces
)

from kaolin.io.shapenet import (
    ShapeNetV2
)

from kaolin.metrics.trianglemesh import (
    point_to_mesh_distance,

)

from torch.utils.data import DataLoader

# add path for demo utils functions 
sys.path.append(os.path.abspath(''))
sys.path.append('/content/drive/MyDrive/CV_DMTet/')

# Setup dashboard to 

In [147]:
#Use pyngrok to access localhost:80 on Colab

!pip install pyngrok --quiet 
from pyngrok import ngrok

# Terminate open tunnels if exist
ngrok.kill()

# Setting the authtoken (optional)
# Get authtoken from https://dashboard.ngrok.com/auth
NGROK_AUTH_TOKEN = "2Hzzzh94FgOXssVkSP5Yffz8uYg_By2RMDZLTPx1aXakhYfH"
ngrok.set_auth_token(NGROK_AUTH_TOKEN)

In [148]:
#generating a public url mapped to localhost 80
public_url = ngrok.connect(port=80, proto="http", options={"bind_tls": True, "local": True})
print("Tracking URL:", public_url)

#Start Kaolin Dash3D on localhost:80 
# Can run this in terminal to not interfere with notebook
# !kaolin-dash3d --logdir=/content/drive/MyDrive/CV_DMTet/Logs --port=80

Tracking URL: NgrokTunnel: "http://52ee-34-141-173-195.ngrok.io" -> "http://localhost:80"


# Import Dataset: Subset of ShapeNetV2

In [6]:
import pytorch3d

In [7]:
# state_dict = torch.load('/content/drive/MyDrive/CV_DMTet/shapenet.pvcnn.c1.pth.tar')
# state_dict.keys()

In [8]:
sys.path.append(INSTALL_PATH / "examples")
sys.path.append(INSTALL_PATH / "examples" / "tutorial")

In [9]:
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

In [10]:
# # path to the point cloud to be reconstructed
# pcd_path = KAOLIN_PATH / "examples/samples/bear_pointcloud.usd"
# # path to the output logs (readable with the training visualizer in the omniverse app)
logs_path = '/content/drive/MyDrive/CV_DMTet/Logs'

# # We initialize the timelapse that will store USD for the visualization apps
timelapse = kaolin.visualize.Timelapse(logs_path)

#Load Tetrahedral grid

DMTet starts from a uniform tetrahedral grid of predefined resolution, and uses a network to predict the SDF value as well as deviation vector at each grid vertex.

Here we load the pre-generated tetrahedral grid using Quartet at resolution 128, which has roughly the same number of vertices as a voxel grid of resolution 65. We use a simple MLP + positional encoding to predict the SDF and deviation vectors in DMTet, and initialize the encoded SDF to represent a sphere.

In [565]:
# Uniform Tetrahedral Grid
tets_verts = torch.tensor(np.load(KAOLIN_PATH / "examples/samples/128_verts.npz")['data'], dtype=torch.float, device=device)
tets = torch.tensor(([np.load(KAOLIN_PATH / 'examples/samples/128_tets_{}.npz'.format(i))['data'] for i in range(4)]), dtype=torch.long, device=device).permute(1,0)
print(tets_verts, tets)


tensor([[ 0.5000,  0.5000,  0.4844],
        [ 0.4844,  0.5000,  0.4922],
        [ 0.4922,  0.4844,  0.4844],
        ...,
        [-0.1719, -0.5000,  0.4766],
        [-0.1562, -0.5000,  0.4688],
        [-0.1562, -0.4922,  0.4688]], device='cuda:0') tensor([[     0,      1,      2,      3],
        [     2,      3,      1,      4],
        [     5,      3,      0,      2],
        ...,
        [277409, 272920, 272914, 272919],
        [272919, 277409, 272920, 274866],
        [277409, 277400, 272920, 274866]], device='cuda:0')


# Loading from ShapeNet

In [726]:
# Not using for now, using libigl tutorial x-cylinder below
SHAPENET_PATH = "/content/drive/MyDrive/CV_DMTet/Core"
#SHAPENET_PATH = "/content/drive/MyDrive/FALL 2022/Computer Vision/Project/Core"
# SYNSETS_IDS = ['02747177', '02773838', '02801938', '02808440', '02818832', '02828884', '02843684'] #'02871439', '02876657', '02880940', '02924116', '02933112']
SYNSETS_IDS = ['02808440']
shapenet_train = ShapeNetV2(SHAPENET_PATH, categories=SYNSETS_IDS, output_dict=True)
shapenet_test = ShapeNetV2(SHAPENET_PATH, categories=SYNSETS_IDS, output_dict=True, train=False)

In [727]:
def get_next_shapenet(idx, object="bench"):
  idx +=1
  while 1:
    try:
      gt_model = shapenet_train[idx]["mesh"]
      break
    except:
      idx+=1
  
  gt_verts = gt_model[0].to(device)
  gt_faces = gt_model[1].to(device)
  gt_verts = ((gt_verts - ((gt_verts.max(0)[0] + gt_verts.min(0)[0]) / 2)) / ((gt_verts.max(0)[0] - gt_verts.min(0)[0]).max()))* 0.8

  wt_grid = kaolin.ops.conversions.trianglemeshes_to_voxelgrids(
        vertices=gt_verts.unsqueeze(0),
        faces=gt_faces,
        resolution=64
    )
  wt_verts, wt_faces = kaolin.ops.conversions.voxelgrids_to_cubic_meshes(wt_grid)
  wt_verts, wt_faces = wt_verts[0], wt_faces[0]

  center = (wt_verts.max(0)[0] + wt_verts.min(0)[0]) / 2
  max_l = (wt_verts.max(0)[0] - wt_verts.min(0)[0]).max()
  wt_verts = ((wt_verts - center) / max_l)* 0.8

  points = kaolin.ops.mesh.sample_points(gt_verts.unsqueeze(0), gt_faces, 5000)[0][0]
  center = (points.max(0)[0] + points.min(0)[0]) / 2
  max_l = (points.max(0)[0] - points.min(0)[0]).max()
  points = ((points - center) / max_l)* 0.9

  timelapse.add_mesh_batch(
      category=f'gt_bench_{idx}',
      vertices_list=[gt_verts.cpu()],
      faces_list=[gt_faces.cpu()]
  )
  gt_verts.to(device)
  gt_faces.to(device)
  wt_grid.to(device)
  wt_verts.to(device)
  wt_faces.to(device)
  points.to(device)
  return idx, gt_verts, gt_faces, wt_grid, wt_verts, wt_faces, points

idx, gt_verts, gt_faces, wt_grid, wt_verts, wt_faces, points = get_next_shapenet(-1)

# Convert model to watertight meshes

In [728]:
def group_shapenet(idx, groupnum, obj="bench"):
  group = []
  for g in range(groupnum):
    idx, gt_verts, gt_faces, wt_grid, wt_verts, wt_faces, points = get_next_shapenet(idx, obj)
    dct = {"idx": idx, "gt_verts": gt_verts, "gt_faces": gt_faces, "wt_grid": wt_grid, 
           "wt_verts": wt_verts, "wt_faces": wt_faces, "points" : points}
    group.append(dct)
  return group


In [730]:
group = group_shapenet(-1, 4, "bathtub")

In [569]:
# # Clone sample meshes from libigl tutorial 1x
# LIBIGL_TUTORIAL_DATA = INSTALL_PATH / "libigl-tutorial-data"
# print(f"LIBIGL_TUTORIAL_DATA: {LIBIGL_TUTORIAL_DATA}")
# %cd $INSTALL_PATH
# !if [ ! -d $LIBIGL_TUTORIAL_DATA ]; then git clone --recursive https://github.com/libigl/libigl-tutorial-data.git; fi;
# %cd $KAOLIN_PATH

In [570]:
# Load cylinder (along x-axis) mesh (V,F) = (42,80)
# mesh = kaolin.io.obj.import_mesh(LIBIGL_TUTORIAL_DATA / "arm.obj") #xcylinder.obj
# gt_verts = mesh[0].to(device)
# gt_faces = mesh[1].to(device)
# timelapse.add_mesh_batch(
#     category='gt',
#     vertices_list=[gt_verts.cpu()],
#     faces_list=[gt_faces.cpu()]
# )
# gt_sample = int(0.99*gt_verts.shape[0])

# Convert model to watertight meshes

We used a voxelization with resolution of 64 to predict the sdf and extract surface. 

In [571]:
# wt_grid
# Convert mesh (V,F) to a voxel grid 
# voxels = kaolin.ops.conversions.trianglemeshes_to_voxelgrids(
#     vertices=mesh_verts.unsqueeze(0).to(device),
#     faces=mesh_faces.to(device),
#     resolution=16
# )

# # Convert the voxel grid back into a mesh as GroundTruth
# # I am not sure this step is necessary
# # The intention is to have mesh_vertices, mesh_faces for the voxel grid for loss function
# wt_verts, wt_faces = kaolin.ops.conversions.voxelgrids_to_cubic_meshes(voxels)
# wt_verts, wt_faces = wt_verts[0], wt_faces[0]

# max_gt_len = (mesh_verts.max(0)[0] - mesh_faces.min(0)[0]).max()
# max_vox_len = (wt_verts.max(0)[0] - wt_verts.min(0)[0]).max()
# scale = max_gt_len / max_vox_len
# center = (wt_verts.max(0)[0] + wt_verts.min(0)[0]) / 2
# wt_verts = ((wt_verts - center) * scale)

# timelapse.add_mesh_batch(
#     category='watertight_test',
#     vertices_list=[wt_verts.cpu()],
#     faces_list=[wt_faces.cpu()]
# )

In [572]:
# Using Kaolin bear point cloud data (PCD) right now since tutorial uses PCD
# pcd_path = KAOLIN_PATH / "examples/samples/bear_pointcloud.usd"
# points = kaolin.io.usd.import_pointclouds(str(pcd_path))[0].points.to(device)
# if points.shape[0] > 100000:
#     idx = list(range(points.shape[0]))
#     np.random.shuffle(idx)
#     idx = torch.tensor(idx[:100000], device=points.device, dtype=torch.long)    
#     points = points[idx]

# # The reconstructed object needs to be slightly smaller than the grid to get watertight surface after MT.
# # The idea is that we want our point cloud to expand since it will converge faster if all face expansions are outward.
# center = (points.max(0)[0] + points.min(0)[0]) / 2
# max_l = (points.max(0)[0] - points.min(0)[0]).max()
# points = ((points - center) / max_l)* 0.9
# timelapse.add_pointcloud_batch(category='input',
#                                pointcloud_list=[points.cpu()], points_type = "usd_geom_points")

kaolin.ops.conversions.voxelgrids_to_cubic_meshes(

Convert voxelgrids to meshes by replacing each occupied voxel with a cuboid mesh (unit cube). Each cube has 8 vertices and 6 (for quadmesh) or 12 faces (for triangular mesh). Internal faces are ignored. If is_trimesh==True, this function performs the same operation as “Cubify” defined in the ICCV 2019 paper “Mesh R-CNN”: https://arxiv.org/abs/1906.02739.

Parameters
voxelgrids (torch.Tensor) – binary voxel array, of shape .

is_trimesh (optional, bool) – if True, the outputs are triangular meshes. Otherwise quadmeshes are returned. Default: True.

Returns
The list of vertices for each mesh.

The list of faces for each mesh.

Return type
(list[torch.Tensor], list[torch.LongTensor])

kaolin.ops.conversions.marching_tetrahedra(vertices, tets, sdf, return_tet_idx=False)¶
Convert discrete signed distance fields encoded on tetrahedral grids to triangle meshes using marching tetrahedra algorithm as described in An efficient method of triangulating equi-valued surfaces by using tetrahedral cells. The output surface is differentiable with respect to input vertex positions and the SDF values. For more details and example usage in learning, see Deep Marching Tetrahedra: a Hybrid Representation for High-Resolution 3D Shape Synthesis NeurIPS 2021.


# SDF model

We follow the paper recommandation and use a four-layer
MLPs with hidden dimensions 256, 256, 128 and 64, respectively

In [731]:
# Since we skip PVCNN, input dimension is just the coordinates of each grid

SDF_MLP_CONFIG = {
    'input_dim' : 3 + 3, # Coordinates of the grid's vertices
    'hidden_dims' : [256, 256, 128, 64],
    'output_dim' : 1, # SDF of the vertex input. The other "output" f_v comes from the prior activation layer of dimension 64
    'multires': 2
}

ENCODER_MLP_CONFIG = {
    'input_dim' : 3, # Coordinates of the grid's vertices
    'hidden_dims' : [256, 800, 1600, 1600],
    'output_dim' : 832, # SDF of the vertex input. The other "output" f_v comes from the prior activation layer of dimension 64
}

lr = 0.001
laplacian_weight = 0.1
iterations = 3000
save_every = 100
multires = 2
grid_res = 128

In [732]:
def get_pred_sdfs(model, tets_verts, gt_verts):
  model.eval()

  with torch.no_grad():
    F_vol = F.interpolate(gt_verts[None, None, None,:, :], size=(1,tets_verts.shape[0], tets_verts.shape[1]), mode='trilinear', align_corners=False)

    F_vol = torch.cat((tets_verts, F_vol.squeeze(0).squeeze(0).squeeze(0).to(device)), dim=1)
    pred_sdfs, f_vs = model(F_vol)

  return pred_sdfs, f_vs

# Get ground truth sdf from input verts and faces
# input shape: [batch_size, num_vertices, 3], [num_faces, 4], [batch_size, num_points, 3]
# output shape: [num_points, 1]
def get_gt_sdfs(gt_verts, gt_faces, points, f=None):

  if f == None:
    f = index_vertices_by_faces(gt_verts, gt_faces)
  d,_,_ = point_to_mesh_distance(points, f)

  s = kaolin.ops.mesh.check_sign(gt_verts, gt_faces, points)
  d[s==True] *= -1

  return d.squeeze(0).unsqueeze(1)

In [733]:
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from pytorch3d.utils import ico_sphere

# Comes from the DMTet tutorial 
# MLP + Positional Encoding
class Decoder(torch.nn.Module):
    def __init__(self, input_dims = 6, internal_dims = 64, output_dims = 4, hidden = 5, multires = 2):
        super().__init__()
        self.embed_fn = None
        self.pre_train_dims = input_dims
        if multires > 0:
            embed_fn, input_ch = get_embedder(multires)
            self.embed_fn = embed_fn
            input_dims = input_ch 
        self.input_dim = input_dims

        # Hidden layers
        self.hiddens = nn.ModuleList()
        net = (torch.nn.Linear(self.input_dim, internal_dims, bias=False), torch.nn.ReLU())
        for i in range(hidden-1):
            net = net + (torch.nn.Linear(internal_dims, internal_dims, bias=False), torch.nn.ReLU())
        self.net = torch.nn.Sequential(*net)
        self.output = torch.nn.Linear(internal_dims, output_dims, bias=False)

    def forward(self, p):
        if self.embed_fn is not None:
            p = self.embed_fn(p)
        p = self.net(p)
        out = self.output(p)
        return out, p
    
    def pre_train_sphere(self, iter):    
        print ("Initialize SDF to sphere")
        loss_fn = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(list(self.parameters()), lr=1e-4)
        

        for i in tqdm(range(iter)):
            p = torch.rand((1024, self.pre_train_dims), device=device) - 0.5
            ref_value  = torch.sqrt(((p+.5)**2).sum(-1)) - 0.3
            output, _ = self(p)
            loss = loss_fn(output[...,0], ref_value)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print("Pre-trained MLP", loss.item())

# Positional Encoding from https://github.com/yenchenlin/nerf-pytorch/blob/1f064835d2cca26e4df2d7d130daa39a8cee1795/run_nerf_helpers.py
class Embedder:
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()
        
    def create_embedding_fn(self):
        embed_fns = []
        d = self.kwargs['input_dims']
        out_dim = 0
        if self.kwargs['include_input']:
            embed_fns.append(lambda x : x)
            out_dim += d
            
        max_freq = self.kwargs['max_freq_log2']
        N_freqs = self.kwargs['num_freqs']
        
        if self.kwargs['log_sampling']:
            freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
        else:
            freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
            
        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq))
                out_dim += d
                    
        self.embed_fns = embed_fns
        self.out_dim = out_dim
        
    def embed(self, inputs):
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)

def get_embedder(multires):
    embed_kwargs = {
                'include_input' : True,
                'input_dims' : 6,
                'max_freq_log2' : multires-1,
                'num_freqs' : multires,
                'log_sampling' : True,
                'periodic_fns' : [torch.sin, torch.cos],
    }
    
    embedder_obj = Embedder(**embed_kwargs)
    embed = lambda x, eo=embedder_obj : eo.embed(x)
    return embed, embedder_obj.out_dim

class MLP(torch.nn.Module):
    def __init__(self, config):
        super(MLP, self).__init__()

        self.input_dim = config['input_dim']
        self.hidden_dims  = config['hidden_dims']
        self.output_dim = config['output_dim']

        # Hidden layers
        self.hiddens = nn.ModuleList()
        in_dim = self.input_dim
        for k in range(len(self.hidden_dims)):
            self.hiddens.append(nn.Linear(in_dim, self.hidden_dims[k]))
            in_dim = self.hidden_dims[k]

        # Output layer
        self.output_layer = torch.nn.Linear(self.hidden_dims[-1], self.output_dim)

    def forward(self, x):
        for hidden in self.hiddens :
            x = F.relu(hidden(x))
        output = self.output_layer(x) # No activation (linear) cuz we do regression

        return output, x # Return output + last feature layer vector

    def pre_train_sphere(self, iter):    
        print ("Initialize SDF to sphere")
        loss_fn = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(list(self.parameters()), lr=1e-4)
        

        for i in tqdm(range(iter)):
            p = torch.rand((1024,self.input_dim), device=device) - 0.5
            ref_value  = torch.sqrt(((p+.5)**2).sum(-1)) - 0.3
            output, _ = self(p)
            loss = loss_fn(output[...,0], ref_value)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print("Pre-trained MLP", loss.item())

In [734]:
r = Decoder(multires=2).to(device)
# r.pre_train_sphere(100)

In [735]:
ref_value.shape

torch.Size([1024])

In [736]:
%env CUDA_LAUNCH_BLOCKING=1

env: CUDA_LAUNCH_BLOCKING=1


In [737]:
from tqdm import tqdm
# # Initialize model and create optimizer

# sdf_model = Decoder(multires=2).to(device)
# sdf_model.pre_train_sphere(1000)

In [738]:
_, y = get_embedder(multires=0)
print(y)

6


In [739]:
encoder_model = MLP(SDF_MLP_CONFIG).to(device)
print(encoder_model)
# print('\n\n')
# summary(encoder_model, input_size=tets_verts.shape + 3)

MLP(
  (hiddens): ModuleList(
    (0): Linear(in_features=6, out_features=256, bias=True)
    (1): Linear(in_features=256, out_features=256, bias=True)
    (2): Linear(in_features=256, out_features=128, bias=True)
    (3): Linear(in_features=128, out_features=64, bias=True)
  )
  (output_layer): Linear(in_features=64, out_features=1, bias=True)
)


Select SDF Model
* sdf_model is standard MLP (all hidden dims are same) + f_v + comes with encoder deformer
* encoder_model is from DMTet and encodes layers + f_v + no deformer

In [740]:
MODEL = encoder_model

Little Test:

In [742]:
# signed distance fields encoded on tetrahedral grid

# pred_sdfs dim = (tets_vertices.shape[0])
# f_vs ie f_v feature vector

# sdf, delta x_i, delta y_i, delta z_i, 
# sdf = binary occupancy of tetrahedron 
# pred, _ = MODEL(tets_verts)
print(f'pred shape is : {pred.shape}')

# pred_sdfs, f_vs = sdf_model(tets_verts)

print(f'Input grid shape is : {tets_verts.shape}')
# print(f'Output shape of the predicted SDFs should be {tets_verts.shape[0], 1} and it actually is {tuple(pred_sdfs.shape)}')
# print(f'Output shape of the feature vectors f_vs should be {tets_verts.shape[0], 64} and it actually is {tuple(f_vs.shape)}')
# print(pred_sdfs)

pred shape is : torch.Size([277410, 4])
Input grid shape is : torch.Size([277410, 3])


# Set up Optimizer

In [266]:
sdf_vars = [p for _, p in MODEL.named_parameters()]
sdf_optimizer = torch.optim.Adam(sdf_vars, lr=lr)
sdf_scheduler = torch.optim.lr_scheduler.LambdaLR(sdf_optimizer, lr_lambda=lambda x: max(0.0, 10**(-x*0.0002))) # LR decay over time

# Training

In [267]:
# # # # takes in a module and applies the specified weight initialization
# def weights_init_normal(m):
#     '''Takes in a module and initializes all linear layers with weight
#         values taken from a normal distribution.'''
#     classname = m.__class__.__name__
#     # for every Linear layer in a model
#     if classname.find('Linear') != -1:
#         y = m.in_features
#     # m.weight.data shoud be taken from a normal distribution
#         m.weight.data.normal_(0.0,1/np.sqrt(y))
#         m.weight.data = torch.randint(-1,2,m.weight.shape)
#     # m.bias.data should be 0
#         if m.bias is not None:
#           m.bias.data.fill_(0)


2

In [616]:

# sdf_model.apply(weights_init_normal)

In [269]:
# sdf_model.net[10].weight.data

In [649]:
# Laplacian regularization using umbrella operator (Fujiwara / Desbrun).
# https://mgarland.org/class/geom04/material/smoothing.pdf
from pytorch3d import loss
from random import randint
def laplace_regularizer_const(pred_mesh_verts, pred_mesh_faces):
    term = torch.zeros_like(pred_mesh_verts, device=device)
    norm = torch.zeros_like(pred_mesh_verts[..., 0:1], device=device)

    v0 = pred_mesh_verts[pred_mesh_faces[:, 0], :]
    v1 = pred_mesh_verts[pred_mesh_faces[:, 1], :]
    v2 = pred_mesh_verts[pred_mesh_faces[:, 2], :]

    term.scatter_add_(0, pred_mesh_faces[:, 0:1].repeat(1,3), (v1 - v0) + (v2 - v0))
    term.scatter_add_(0, pred_mesh_faces[:, 1:2].repeat(1,3), (v0 - v1) + (v2 - v1))
    term.scatter_add_(0, pred_mesh_faces[:, 2:3].repeat(1,3), (v0 - v2) + (v1 - v2))

    two = torch.ones_like(v0) * 2.0
    norm.scatter_add_(0, pred_mesh_faces[:, 0:1], two)
    norm.scatter_add_(0, pred_mesh_faces[:, 1:2], two)
    norm.scatter_add_(0, pred_mesh_faces[:, 2:3], two)

    term = term / torch.clamp(norm, min=1.0)

    return torch.mean(term**2)

def loss_f(pred_mesh_verts, pred_mesh_faces, pts, it):
    pred_points = kaolin.ops.mesh.sample_points(pred_mesh_verts.unsqueeze(0), 
                                                pred_mesh_faces, 5000)[0][0]
    
    chamfer = kaolin.metrics.pointcloud.chamfer_distance(pred_points.unsqueeze(0), pts.unsqueeze(0)).mean()
    # chamfer.clone().detach().requires_grad_(True)
    if it > iterations//2:
        lap = laplace_regularizer_const(pred_mesh_verts, pred_mesh_faces)
        return chamfer + lap * laplacian_weight
    return chamfer 


def sdf_train(iterations, model, optimizer, scheduler):
  # Set to training mode
  model.train()

  for it in range(iterations):
      gidx = randint(0, len(group) - 1)
      points = group[gidx]["points"]

      wt_verts = group[gidx]["wt_verts"]

      F_vol = F.interpolate(points[None, None, None,:, :], size=(1,tets_verts.shape[0], tets_verts.shape[1]), mode='trilinear', align_corners=False)
      F_vol = torch.cat((tets_verts, F_vol.squeeze(0).squeeze(0).squeeze(0).to(device)), dim=1)
      
      pred, _ = model(F_vol)
      
      # Replace pred_sdfs with the new SDF
      # Replace deform with surface refinement 

      pred_sdfs, deform = pred[:,0], pred[:,1:]
      verts_deformed = tets_verts + torch.tanh(deform) / grid_res # constraint deformation to avoid flipping tets
      pred_mesh_verts, pred_mesh_faces = marching_tetrahedra(verts_deformed.unsqueeze(0), tets, pred_sdfs.unsqueeze(0))
      pred_mesh_verts, pred_mesh_faces = pred_mesh_verts[0], pred_mesh_faces[0]
      if pred_mesh_faces.shape[0] == 0:
        pred_sdfs[tets[0,0]] = -1
        pred_sdfs[tets[0,1]] = 1
        pred_sdfs[tets[0,2]] = 1
        pred_sdfs[tets[0,3]] = 1

        pred_mesh_verts, pred_mesh_faces = marching_tetrahedra(verts_deformed.unsqueeze(0), tets, pred_sdfs.unsqueeze(0))
        pred_mesh_verts, pred_mesh_faces = pred_mesh_verts[0], pred_mesh_faces[0]

      loss = loss_f(pred_mesh_verts, pred_mesh_faces, wt_verts, it)
      
      optimizer.zero_grad()
      torch.autograd.set_detect_anomaly(True)
      loss.backward(retain_graph=True)
      optimizer.step()
      scheduler.step()

      if (it) % save_every == 0 or it == (iterations - 1): 
          print ('Iteration {} - loss: {}'.format(it, loss))
          # save reconstructed mesh
          timelapse.add_mesh_batch(
              iteration=it+1,
              category='extracted_mesh',
              vertices_list=[pred_mesh_verts.cpu()],
              faces_list=[pred_mesh_faces.cpu()]
          )

def encoder_train(iterations, model, optimizer, scheduler, tv = None):
  # Set to training mode
  model.train()

  if tv is None:
    tv = tets_verts

  for it in range(iterations):
      pred_sdfs, _ = model(tv)
      
      gt_sdfs = get_gt_sdfs(wt_verts.unsqueeze(0), wt_faces, tets_verts.unsqueeze(0))

      loss = F.mse_loss(pred_sdfs, gt_sdfs) 
      
      optimizer.zero_grad()
      torch.autograd.set_detect_anomaly(True)
      loss.backward(retain_graph=True)
      optimizer.step()
      scheduler.step()

      if (it) % save_every == 0 or it == (iterations - 1): 
          print ('Iteration {} - loss: {}'.format(it, loss))
          # save reconstructed mesh
          # timelapse.add_mesh_batch(
          #     iteration=it+1,
          #     category='extracted_mesh',
          #     vertices_list=[pred_mesh_verts.cpu()],
          #     faces_list=[pred_mesh_faces.cpu()]
          # )


#Hide training MLP

In [271]:
# MODEL.pre_train_sphere(1000)
# sdf_vars = [p for _, p in MODEL.named_parameters()]
# sdf_optimizer = torch.optim.Adam(sdf_vars, lr=lr)
# sdf_scheduler = torch.optim.lr_scheduler.LambdaLR(sdf_optimizer, lr_lambda=lambda x: max(0.0, 10**(-x*0.0002))) # LR decay over time
# # ~12 min. Speed up?
# encoder_train(iterations, MODEL, sdf_optimizer, sdf_scheduler)


 # Surface refinement utils

In [272]:
# Get edges lists of shape [E, 2] from face list of shape [V, 4]

def get_edges(input):
  c = torch.combinations(torch.arange(input.size(1)), r=2)
  x = input[:,None].expand(-1,len(c),-1).cpu()
  idx = c[None].expand(len(x), -1, -1)
  x = x.gather(dim=2, index=idx)

  return x.view(-1, *x.shape[2:])

# Extract tets under certain sdf restrictions:
# if thresh = 0, return all surface tetrahedrons
# if thresh > 0, return all tetrahedrons whose vertices' sdfs are all in the range [-thresh, thresh]

def extract_tet(tets, sdf, thresh, non_surf=False):

  assert thresh >= 0

  if thresh == 0:
    mask = sdf[tets] > 0
    mask_int = mask.squeeze(2).long()
    t = mask_int.sum(1)
    surf_tets = tets[(t > 0) & (t < 4)]
  else:
    mask = (sdf[tets] >= -thresh) & (sdf[tets] <= thresh)
    mask_int = mask.squeeze(2).long()
    t = mask_int.sum(1)
    surf_tets = tets[t == 4]

  surf_tets_tuple = surf_tets.unique(return_inverse=True)
  surf_tets_idx, surf_tets = surf_tets_tuple[0], surf_tets_tuple[1]

  if non_surf:
    if thresh == 0:
      non_surf_tets = tets[~((t > 0) & (t < 4))]
    else:
      non_surf_tets = tets[t < 4]
    non_surf_tets_tuple = non_surf_tets.unique(return_inverse=True)
    non_surf_tets_idx, non_surf_tets = non_surf_tets_tuple[0], non_surf_tets_tuple[1]
    return surf_tets_idx, surf_tets, non_surf_tets_idx, non_surf_tets

  return surf_tets_idx, surf_tets

#Get output from initial MLP



# Surface refinement model

In [744]:
"""
Graph-res net:
Identify surface tetrahedral, build adj matrix, 
"""

from tqdm import tqdm

from torch import nn
from torch.nn import functional as F
from pytorch3d.ops import GraphConv

# a single res block layer with dimension 256 & 128
class GResBlock(nn.Module): 
    def __init__(self, in_dim, hidden_dim, activation=None):
        super(GResBlock, self).__init__()

        self.conv1 = GraphConv(in_dim, hidden_dim)
        self.conv2 = GraphConv(hidden_dim, in_dim)
        self.activation = F.relu if activation else None
    
    def forward(self, inputs):
        input, adj = inputs[0], inputs[1]
        x = self.conv1(input, adj)
        if self.activation:
          x = self.activation(x)
        x = self.conv2(x, adj)
        if self.activation:
          x = self.activation(input + x)
        
        return [x, adj]

class GBottleneck(nn.Module):
    def __init__(self, block_num, in_dim, hidden_dim, out_dim, activation=None):
        super(GBottleneck, self).__init__()

        resblock_layers = [GResBlock(in_dim=hidden_dim[0], hidden_dim=hidden_dim[1], activation=activation)
                          for _ in range(block_num)]
        self.blocks = nn.Sequential(*resblock_layers)
        self.conv1 = GraphConv(in_dim, hidden_dim[0])

        self.activation = F.relu if activation else None
    
    def forward(self, inputs, adj):
        x = self.conv1(inputs, adj)
        if self.activation:
          x = self.activation(x)
        x = self.blocks([x, adj])[0]
        if self.activation:
          x = self.activation(x)

        return x

class GCN_Res(nn.Module):
    def __init__(self, config):
        super(GCN_Res, self).__init__()

        self.in_dim = config['in_dim']
        self.hidden_dim = config['hidden_dim']
        self.out_dim = config['out_dim']
        self.activation = config['activation']
        self.mlp_hdim = config['mlp_hdim']
        self.mlp_odim = config['mlp_odim']

        self.gcn_res = nn.ModuleList([GBottleneck(2, self.in_dim, self.hidden_dim, self.out_dim, self.activation)])

        self.sdf_mlp = nn.Sequential(
            nn.Linear(self.out_dim, self.mlp_hdim[0], bias=False),
            nn.Linear(self.mlp_hdim[0], self.mlp_hdim[1], bias=False),
            nn.Linear(self.mlp_hdim[1], 1, bias=False),
        )

        self.deform_mlp = nn.Sequential(
            nn.Linear(self.out_dim, self.mlp_hdim[0], bias=False),
            nn.Linear(self.mlp_hdim[0], self.mlp_hdim[1], bias=False),
            nn.Linear(self.mlp_hdim[1], 3, bias=False),
        )

        self.feature_mlp = nn.Sequential(
            nn.Linear(self.out_dim, self.mlp_hdim[0], bias=False),
            nn.Linear(self.mlp_hdim[0], self.mlp_hdim[1], bias=False),
            nn.Linear(self.mlp_hdim[1], self.mlp_hdim[1], bias=False),
        )
        


    def forward(self, inputs, adj):

        x = self.gcn_res[0](inputs, adj)

        sdf = self.sdf_mlp(x)
        deform = self.deform_mlp(x)
        deform = torch.tanh(deform)
        feature = self.feature_mlp(x)
        
        return sdf, deform, feature

# GCN Loss Function

In [745]:
from pytorch3d.loss import(
    chamfer_distance
)

def laplace_regularizer_const(mesh_verts, mesh_faces):
    term = torch.zeros_like(mesh_verts)
    norm = torch.zeros_like(mesh_verts[..., 0:1])

    v0 = mesh_verts[mesh_faces[:, 0], :]
    v1 = mesh_verts[mesh_faces[:, 1], :]
    v2 = mesh_verts[mesh_faces[:, 2], :]

    term.scatter_add_(0, mesh_faces[:, 0:1].repeat(1,3), (v1 - v0) + (v2 - v0))
    term.scatter_add_(0, mesh_faces[:, 1:2].repeat(1,3), (v0 - v1) + (v2 - v1))
    term.scatter_add_(0, mesh_faces[:, 2:3].repeat(1,3), (v0 - v2) + (v1 - v2))

    two = torch.ones_like(v0) * 2.0
    norm.scatter_add_(0, mesh_faces[:, 0:1], two)
    norm.scatter_add_(0, mesh_faces[:, 1:2], two)
    norm.scatter_add_(0, mesh_faces[:, 2:3], two)

    term = term / torch.clamp(norm, min=1.0)

    return torch.mean(term**2)

def gcn_loss(iterations, mesh_verts, mesh_faces, gt_verts, gt_faces, it):

    #surface alignment loss
    pm = pytorch3d.structures.Meshes([mesh_verts], [mesh_faces])
    gm = pytorch3d.structures.Meshes([gt_verts], [gt_faces])
  
    if mesh_verts.shape[0] > 0:
      pred_points = kaolin.ops.mesh.sample_points(mesh_verts.unsqueeze(0), mesh_faces, 100000)[0][0]
      gt_points = kaolin.ops.mesh.sample_points(gt_verts.unsqueeze(0), gt_faces, 100000)[0][0]
      chamfer = kaolin.metrics.pointcloud.chamfer_distance(pred_points.unsqueeze(0), gt_points.unsqueeze(0), squared=False).mean()
    else:
      chamfer = 0


    if it > iterations//2:
      lap = laplace_regularizer_const(mesh_verts, mesh_faces)
      return 500*chamfer + lap * laplacian_weight 
    return 500*chamfer

# Train GCN Model for Surface Refinement

In [753]:
CONFIG_GCNRES = {
    'in_dim': 68,
    'hidden_dim': [128, 256],
    'out_dim': 128,
    'activation': True,
    'mlp_hdim': [128,64],
    'mlp_odim': 68,
}

# Same set of Hyperparam is applied
lr = 1e-4
laplacian_weight = 0.1
gcn_iterations = 5000
save_every = 100
multires = 2
grid_res = 128

In [758]:
def test_train(iterations, sdf_model, gcn_model, optimizer, scheduler, epoch):
  gcn_model.train()
  sdf_model.train()
  avg_loss = 0
  wt_models = []

  for i in range(iterations):

    gidx = randint(0, len(group) - 1)
    g = group[gidx]

    wt_grid = g["wt_grid"]
    wt_verts = g["wt_verts"]
    wt_faces = g["wt_faces"]
    gt_verts = g["gt_verts"]
    gt_faces = g["gt_faces"]



    F_vol = F.interpolate(wt_grid.unsqueeze(0), size=(1,tets_verts.shape[0], tets_verts.shape[1]), mode='trilinear', align_corners=False)

    F_vol = torch.cat((tets_verts, F_vol.squeeze(0).squeeze(0).squeeze(0).to(device)), dim=1)

    pred_sdfs, f_vs = sdf_model(F_vol)

    """
    Surface Refinement
    """

    surf_tets_verts_idx, surf_tets_faces = extract_tet(tets, pred_sdfs, 0.02) #if not working modify the 0.008 here; this is the threshold for surface sdf value
    surf_tets_verts = torch.clone(tets_verts[surf_tets_verts_idx])

    surf_tets_verts_features = torch.clone(f_vs[surf_tets_verts_idx])
    surf_sdfs = pred_sdfs[surf_tets_verts_idx]
    surf_tets_edges = torch.clone(get_edges(surf_tets_faces).to(device))
    surf_verts_f = torch.cat((surf_tets_verts, surf_sdfs, surf_tets_verts_features), dim=1)

    sdf, deform, fv = gcn_model(surf_verts_f, surf_tets_edges)
    
    """

    Update surface position, sdf, and f_s

    """

    #updated sdf

    update_sdfs = pred_sdfs.clone()
    update_sdfs[surf_tets_verts_idx] += sdf

    #update vertices positions

    update_tets_verts = tets_verts.clone()
    update_tets_verts[surf_tets_verts_idx] += deform / grid_res

    #update vertices features

    update_tets_f = f_vs.clone()
    update_tets_f[surf_tets_verts_idx] += fv

    if epoch < 500:
      gt_sdfs = get_gt_sdfs(wt_verts.unsqueeze(0), wt_faces, update_tets_verts.unsqueeze(0))

      sdf_loss = F.mse_loss(update_sdfs, gt_sdfs, reduction='mean')

      optimizer.zero_grad()
      sdf_loss.backward(retain_graph=True)
      optimizer.step()
      avg_loss += sdf_loss.item()

      if epoch == 0 and i == 0:
        print('========== Start pretraining ==========')
      
      if i == (iterations - 1) and epoch % 100 == 0:
        print ('Epoch {} - loss: {}'.format(epoch, avg_loss/iterations))

      continue


    """

    Marching Tetrahedra based on new sdf value and deformed vertices in the tet grid

    """

    mesh_verts, mesh_faces = kaolin.ops.conversions.marching_tetrahedra(update_tets_verts.unsqueeze(0), tets, update_sdfs.squeeze(1).unsqueeze(0))
    mesh_verts, mesh_faces = mesh_verts[0], mesh_faces[0]

    """

    Compute Loss for First surface refinement: 
    Normal consistency + surface alignment + laplacian smooth + sdf L2-reg + deform L2-reg

    """

    # L2 sdf reg: 

    s_sdfs = get_gt_sdfs(gt_verts.unsqueeze(0), gt_faces, update_tets_verts.unsqueeze(0))

    mask = ((s_sdfs >= -0.3) & (s_sdfs <= 0.3)).squeeze(1)
    p = update_sdfs[mask]
    g = s_sdfs[mask]

    sdf_loss = F.mse_loss(p, g, reduction='mean') 

    #L2 deform reg

    deform_loss = F.mse_loss(update_tets_verts, tets_verts, reduction='mean')

    #surface alignment loss

    r_loss = gcn_loss(iterations, mesh_verts, mesh_faces, gt_verts, gt_faces, i)

    g_loss = r_loss + deform_loss + 0.4*sdf_loss

    optimizer.zero_grad()
    torch.autograd.set_detect_anomaly(True)
    g_loss.backward(retain_graph=True)
    avg_loss += g_loss.item()
    optimizer.step()
    scheduler.step()

    if epoch == 500 and i == 0:
        print('========== Start Refinement ==========')

    
    if epoch % 100 == 0:
      if (i) % 1 == 0: 
        # print ('Iteration {} - loss: {}, # of mesh vertices: {}, # of mesh faces: {}'.format(i, g_loss, mesh_verts.shape[0], mesh_faces.shape[0]))
        
        # save reconstructed mesh
        timelapse.add_mesh_batch(
            iteration=epoch+1,
            category='final_train_res',
            vertices_list=[mesh_verts.cpu()],
            faces_list=[mesh_faces.cpu()]
        )
      
      if i == (iterations - 1):
        print ('Epoch {} - loss: {}, # of mesh vertices: {}, # of mesh faces: {}'.format(epoch, avg_loss/iterations, mesh_verts.shape[0], mesh_faces.shape[0]))

#GCN Optimizer


In [759]:
GCN_SDF_MLP_CONFIG = {
    'input_dim' : 6, # Coordinates of the grid's vertices #previously 3
    'hidden_dims' : [256, 256, 128, 64],
    'output_dim' : 1, # SDF of the vertex input. The other "output" f_v comes from the prior activation layer of dimension 64
    'multires': 2
}
refine_model = GCN_Res(CONFIG_GCNRES).to(device)
sdf_model = MLP(GCN_SDF_MLP_CONFIG).to(device)

print(refine_model)

sdf_vars = [p for _, p in sdf_model.named_parameters()]
sdf_optimizer = torch.optim.Adam(sdf_vars, lr=lr)
sdf_scheduler = torch.optim.lr_scheduler.LambdaLR(sdf_optimizer, lr_lambda=lambda x: max(0.0, 10**(-x*0.0002))) # LR decay over time

params = list(sdf_model.named_parameters()) + list(refine_model.named_parameters())

refine_vars = [p for _, p in params]
refine_optimizer = torch.optim.Adam(refine_vars, lr=lr)
refine_scheduler = torch.optim.lr_scheduler.LambdaLR(refine_optimizer, lr_lambda=lambda x: max(0.0, 10**(-x*0.0002))) # LR decay over time

GCN_Res(
  (gcn_res): ModuleList(
    (0): GBottleneck(
      (blocks): Sequential(
        (0): GResBlock(
          (conv1): GraphConv(128 -> 256, directed=False)
          (conv2): GraphConv(256 -> 128, directed=False)
        )
        (1): GResBlock(
          (conv1): GraphConv(128 -> 256, directed=False)
          (conv2): GraphConv(256 -> 128, directed=False)
        )
      )
      (conv1): GraphConv(68 -> 128, directed=False)
    )
  )
  (sdf_mlp): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=False)
    (1): Linear(in_features=128, out_features=64, bias=False)
    (2): Linear(in_features=64, out_features=1, bias=False)
  )
  (deform_mlp): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=False)
    (1): Linear(in_features=128, out_features=64, bias=False)
    (2): Linear(in_features=64, out_features=3, bias=False)
  )
  (feature_mlp): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=False)
    (1): Linear(in_features

In [760]:
torch.cuda.empty_cache()
for i in tqdm(range(3000)):
  test_train(1, sdf_model, refine_model, refine_optimizer, refine_scheduler, i)

  0%|          | 1/3000 [00:00<24:33,  2.04it/s]

Epoch 0 - loss: 0.04032648727297783


  3%|▎         | 101/3000 [00:45<23:41,  2.04it/s]

Epoch 100 - loss: 0.03344113752245903


  7%|▋         | 201/3000 [01:36<28:18,  1.65it/s]

Epoch 200 - loss: 0.3127651810646057


 10%|█         | 301/3000 [02:56<45:49,  1.02s/it]

Epoch 300 - loss: 0.015374894253909588


 13%|█▎        | 401/3000 [04:47<51:15,  1.18s/it]

Epoch 400 - loss: 0.01308599766343832


 17%|█▋        | 500/3000 [06:54<57:02,  1.37s/it]



 17%|█▋        | 501/3000 [06:56<1:06:14,  1.59s/it]

Epoch 500 - loss: 158.60443115234375, # of mesh vertices: 197362, # of mesh faces: 389557


 20%|██        | 601/3000 [09:30<1:08:19,  1.71s/it]

Epoch 600 - loss: 120.02179718017578, # of mesh vertices: 81847, # of mesh faces: 161432


 23%|██▎       | 701/3000 [12:10<1:18:41,  2.05s/it]

Epoch 700 - loss: 54.75283432006836, # of mesh vertices: 47356, # of mesh faces: 94021


 27%|██▋       | 801/3000 [14:50<56:36,  1.54s/it]

Epoch 800 - loss: 28.627580642700195, # of mesh vertices: 28109, # of mesh faces: 56496


 30%|███       | 901/3000 [17:18<46:28,  1.33s/it]

Epoch 900 - loss: 80.23519134521484, # of mesh vertices: 60688, # of mesh faces: 120098


 33%|███▎      | 1001/3000 [19:02<41:44,  1.25s/it]

Epoch 1000 - loss: 14.30440616607666, # of mesh vertices: 32388, # of mesh faces: 65364


 37%|███▋      | 1101/3000 [20:59<30:59,  1.02it/s]

Epoch 1100 - loss: 13.212596893310547, # of mesh vertices: 27948, # of mesh faces: 56620


 40%|████      | 1201/3000 [22:36<28:50,  1.04it/s]

Epoch 1200 - loss: 12.941976547241211, # of mesh vertices: 31681, # of mesh faces: 64778


 43%|████▎     | 1301/3000 [24:22<31:15,  1.10s/it]

Epoch 1300 - loss: 14.330855369567871, # of mesh vertices: 37542, # of mesh faces: 76268


 47%|████▋     | 1401/3000 [25:52<34:10,  1.28s/it]

Epoch 1400 - loss: 18.272363662719727, # of mesh vertices: 41446, # of mesh faces: 83672


 50%|█████     | 1501/3000 [27:30<21:34,  1.16it/s]

Epoch 1500 - loss: 15.223045349121094, # of mesh vertices: 30970, # of mesh faces: 62564


 53%|█████▎    | 1601/3000 [29:14<24:04,  1.03s/it]

Epoch 1600 - loss: 15.061433792114258, # of mesh vertices: 39004, # of mesh faces: 78764


 57%|█████▋    | 1701/3000 [30:55<31:48,  1.47s/it]

Epoch 1700 - loss: 50.28614807128906, # of mesh vertices: 59404, # of mesh faces: 118672


 60%|██████    | 1801/3000 [32:13<21:47,  1.09s/it]

Epoch 1800 - loss: 20.553956985473633, # of mesh vertices: 35837, # of mesh faces: 71754


 63%|██████▎   | 1901/3000 [33:32<17:23,  1.05it/s]

Epoch 1900 - loss: 20.540023803710938, # of mesh vertices: 33660, # of mesh faces: 67924


 67%|██████▋   | 2001/3000 [34:48<12:44,  1.31it/s]

Epoch 2000 - loss: 19.583452224731445, # of mesh vertices: 36162, # of mesh faces: 73024


 70%|███████   | 2101/3000 [36:07<13:06,  1.14it/s]

Epoch 2100 - loss: 18.526134490966797, # of mesh vertices: 35644, # of mesh faces: 72488


 73%|███████▎  | 2201/3000 [37:29<12:06,  1.10it/s]

Epoch 2200 - loss: 14.334057807922363, # of mesh vertices: 40008, # of mesh faces: 81744


 77%|███████▋  | 2301/3000 [38:42<10:17,  1.13it/s]

Epoch 2300 - loss: 11.856159210205078, # of mesh vertices: 39330, # of mesh faces: 79800


 80%|████████  | 2401/3000 [39:56<09:24,  1.06it/s]

Epoch 2400 - loss: 11.3319091796875, # of mesh vertices: 38388, # of mesh faces: 77856


 83%|████████▎ | 2501/3000 [41:10<07:57,  1.05it/s]

Epoch 2500 - loss: 14.463262557983398, # of mesh vertices: 41242, # of mesh faces: 83920


 87%|████████▋ | 2601/3000 [42:23<05:17,  1.26it/s]

Epoch 2600 - loss: 16.731491088867188, # of mesh vertices: 35648, # of mesh faces: 72416


 90%|█████████ | 2701/3000 [43:42<04:18,  1.15it/s]

Epoch 2700 - loss: 13.373225212097168, # of mesh vertices: 39150, # of mesh faces: 79276


 93%|█████████▎| 2801/3000 [44:52<02:47,  1.18it/s]

Epoch 2800 - loss: 10.384215354919434, # of mesh vertices: 32560, # of mesh faces: 65920


 97%|█████████▋| 2901/3000 [46:10<02:04,  1.25s/it]

Epoch 2900 - loss: 17.10131072998047, # of mesh vertices: 41580, # of mesh faces: 84528


100%|██████████| 3000/3000 [47:17<00:00,  1.06it/s]


# Discriminator

In [None]:
# break

In [None]:
import math
# hyperparameters
N = 18 
r = 128
Kg_min = math.pi / 16
n_sample_patches = 10
# iterations = 20000 

In [None]:
import torch

def calculate_gaussian_curvature(vertices, faces):
    # Calculate the vertex normals
    v_normals = torch.zeros_like(vertices)
    for face in faces:
        # Get the vertices of the current face
        v1 = vertices[face[0]]
        v2 = vertices[face[1]]
        v3 = vertices[face[2]]

        # Calculate the face normal
        f_normal = torch.cross(v2 - v1, v3 - v1)
        f_normal = f_normal / f_normal.norm()

        # Add the face normal to the vertex normals of the vertices of the current face
        v_normals[face[0]] += f_normal
        v_normals[face[1]] += f_normal
        v_normals[face[2]] += f_normal

    # Normalize the vertex normals
    v_normals = v_normals / v_normals.norm(dim=1, keepdim=True)

    # Calculate the Gaussian curvature for each vertex
    gaussian_curvature = torch.zeros(vertices.shape[0])
    for i in range(vertices.shape[0]):
        # Get the neighbors of the current vertex
        neighbors = []
        for face in faces:
            if i in face:
                for j in face:
                    if j != i:
                        neighbors.append(j)
        neighbors = torch.tensor(list(set(neighbors)))

        # Calculate the angle sum at the current vertex
        angle_sum = 0
        for j in neighbors:
            # Calculate the angle between the normals of the current vertex and its neighbor
            angle = torch.acos(torch.clamp(torch.dot(v_normals[i], v_normals[j]), -1, 1))

            # Add the angle to the angle sum
            angle_sum += angle

        # Calculate the Gaussian curvature at the current vertex
        gaussian_curvature[i] = (2 * math.pi - angle_sum) / len(neighbors)

    return gaussian_curvature


In [218]:
DISCRIMINATOR_CONFIG = {
#     'input_dim' : 3, # Coordinates of the grid's vertices
    'kernel_size' : [4, 3, 3, 3, 3],
    'out_channel' : [32, 64, 128, 256, 512],
    'stride' : [1, 2, 1, 1, 1],
    'output_dim' : 512 # SDF of the vertex input. The other "output" f_v comes from the prior activation layer of dimension 64
}

In [219]:
# https://github.com/czq142857/DECOR-GAN/blob/3d736bf0f5bd9206cc26ee5336c1d7b4172f6cf8/evalFID.py#L63
class Discriminator(nn.Module):
    def __init__(self, config):
        super(Discriminator, self).__init__()
        # self.config = config
        self.voxel_size = 64

        self.conv_1 = nn.Conv3d(in_channels=1, out_channels=32, kernel_size=4, stride=1, bias=True)
        self.bn_1 = nn.InstanceNorm3d(32)

        self.conv_2 = nn.Conv3d(in_channels=32, out_channels=64, kernel_size=3, stride=2, bias=True)
        self.bn_2 = nn.InstanceNorm3d(64)

        self.conv_3 = nn.Conv3d(in_channels=64, out_channels=128, kernel_size=3, stride=1, bias=True)
        self.bn_3 = nn.InstanceNorm3d(128)

        self.conv_4 = nn.Conv3d(in_channels=128, out_channels=256, kernel_size=3, stride=1, bias=True)
        self.bn_4 = nn.InstanceNorm3d(256)

        self.conv_5 = nn.Conv3d(in_channels=256, out_channels=512, kernel_size=3, stride=1, bias=True)

#         if self.voxel_size==256:
#             self.bn_5 = nn.InstanceNorm3d(self.z_dim)
#             self.conv_5_2 = nn.Conv3d(self.z_dim, self.z_dim, 4, stride=2, padding=1, bias=True)

        self.linear1 = nn.Linear(512, 1, bias=True)



    def forward(self, inputs):
        out = inputs

        out = self.bn_1(self.conv_1(out))
        out = F.leaky_relu(out, negative_slope=0.01, inplace=True)

        out = self.bn_2(self.conv_2(out))
        out = F.leaky_relu(out, negative_slope=0.01, inplace=True)

        out = self.bn_3(self.conv_3(out))
        out = F.leaky_relu(out, negative_slope=0.01, inplace=True)

        out = self.bn_4(self.conv_4(out))
        out = F.leaky_relu(out, negative_slope=0.01, inplace=True)

        out = self.conv_5(out)

#         if self.voxel_size==256:
#             out = self.bn_5(out)
#             out = F.leaky_relu(out, negative_slope=0.01, inplace=True)
#             out = self.conv_5_2(out)

        z = F.adaptive_avg_pool3d(out, output_size=(1, 1, 1))
        z = z.view(-1,512)
        out = F.leaky_relu(z, negative_slope=0.01, inplace=True)
        
        # Add Fvol here (512) !!!
        
        out = self.linear1(out)

        return out, z

In [220]:
def filter_faces(faces, vertex_gaussians, threshold):
  face_mask = vertex_gaussians[faces[:,0]] > threshold & vertex_gaussians[faces[:,1]] > threshold & vertex_gaussians[faces[:,2]] > threshold
  return faces[face_mask]

In [642]:
# def test_train2(iterations, sdf_model, gcn_model, discrim_model, optimizer, scheduler, epoch):
#   gcn_model.train()
#   sdf_model.train()
#   discrim_model.train()
#   avg_loss = 0
#   wt_models = []

#   for i in range(iterations):

#     F_vol = F.interpolate(wt_grid.unsqueeze(0), size=(1,tets_verts.shape[0], tets_verts.shape[1]), mode='trilinear', align_corners=False)

#     F_vol = torch.cat((tets_verts, F_vol.squeeze(0).squeeze(0).squeeze(0).to(device)), dim=1)

#     pred_sdfs, f_vs = sdf_model(F_vol)

#     """
#     Surface Refinement
#     """

#     surf_tets_verts_idx, surf_tets_faces = extract_tet(tets, pred_sdfs, 0.008) #if not working modify the 0.008 here; this is the threshold for surface sdf value
#     surf_tets_verts = torch.clone(tets_verts[surf_tets_verts_idx])

#     surf_tets_verts_features = torch.clone(f_vs[surf_tets_verts_idx])
#     surf_sdfs = pred_sdfs[surf_tets_verts_idx]
#     surf_tets_edges = torch.clone(get_edges(surf_tets_faces).to(device))
#     surf_verts_f = torch.cat((surf_tets_verts, surf_sdfs, surf_tets_verts_features), dim=1)

#     sdf, deform, fv = gcn_model(surf_verts_f, surf_tets_edges)
    
#     """

#     Update surface position, sdf, and f_s

#     """

#     #updated sdf

#     update_sdfs = pred_sdfs.clone()
#     update_sdfs[surf_tets_verts_idx] += sdf

#     #update vertices positions

#     update_tets_verts = tets_verts.clone()
#     update_tets_verts[surf_tets_verts_idx] += deform / grid_res

#     #update vertices features

#     update_tets_f = f_vs.clone()
#     update_tets_f[surf_tets_verts_idx] += fv

#     #Generate sdf 


#     # shape of SDFs
#     Sreal = np.ones((N,N,N)) #mesh 
#     Sgt = np.ones((N,N,N))

#     # Kg is dimension Vgt x 1
#     Kg = calculate_gaussian_curvature(gt_verts, gt_faces).to(device)
#     # mask_curvature = Kg >= Kg_min

#     # # mask vertices
#     # V_gt = gt_verts[mask_curvature].clone()
#     # V_pred = update_tets_verts[mask_curvature].clone()

#     # remap faces to new 
#     F_gt = filter_faces(gt_faces, Kg, Kg_min)
#     F_pred = filter_faces(surf_tets_faces, Kg, Kg_min)


#     v_gt = kaolin.ops.mesh.sample_points(gt_verts, F_gt, 1)
#     v_pred = kaolin.ops.mesh.sample_points(update_tets_verts, F_pred, 1)

#     # N x N x N, N=18
#     vox_gt = kaolin.ops.conversions.trianglemeshes_to_voxelgrids(v_gt, F_gt, N)
#     vox_pred = kaolin.ops.conversions.trianglemeshes_to_voxelgrids(v_pred, F_pred, N)


#     pred_sdfs, f_vs = sdf_model(F_vol)



#     s_real = SDF(Vgt + (Vgt.shape[0] - N/2) / r, (vgt, fgt))
#     Sreal = SDF(Vgt + (Vgt.shape[0] - N/2) / r, (vgt, fgt))

#     if epoch < 500:
#       gt_sdfs = get_gt_sdfs(wt_verts.unsqueeze(0), wt_faces, update_tets_verts.unsqueeze(0))

#       sdf_loss = F.mse_loss(update_sdfs, gt_sdfs, reduction='mean')

#       optimizer.zero_grad()
#       sdf_loss.backward(retain_graph=True)
#       optimizer.step()
#       avg_loss += sdf_loss.item()

#       if epoch == 0 and i == 0:
#         print('========== Start pretraining ==========')
      
#       if i == (iterations - 1) and epoch % 100 == 0:
#         print ('Epoch {} - loss: {}'.format(epoch, avg_loss/iterations))

#       continue


#     """

#     Marching Tetrahedra based on new sdf value and deformed vertices in the tet grid

#     """

#     mesh_verts, mesh_faces = kaolin.ops.conversions.marching_tetrahedra(update_tets_verts.unsqueeze(0), tets, update_sdfs.squeeze(1).unsqueeze(0))
#     mesh_verts, mesh_faces = mesh_verts[0], mesh_faces[0]

#     """

#     Compute Loss for First surface refinement: 
#     Normal consistency + surface alignment + laplacian smooth + sdf L2-reg + deform L2-reg

#     """

#     # L2 sdf reg: 

#     s_sdfs = get_gt_sdfs(gt_verts.unsqueeze(0), gt_faces, update_tets_verts.unsqueeze(0))

#     mask = ((s_sdfs >= -0.3) & (s_sdfs <= 0.3)).squeeze(1)
#     p = update_sdfs[mask]
#     g = s_sdfs[mask]

#     sdf_loss = F.mse_loss(p, g, reduction='mean') 

#     #L2 deform reg

#     deform_loss = F.mse_loss(update_tets_verts, tets_verts, reduction='mean')

#     #surface alignment loss

#     r_loss = gcn_loss(iterations, mesh_verts, mesh_faces, gt_verts, gt_faces, i)

#     g_loss = r_loss + deform_loss + 0.4*sdf_loss

#     optimizer.zero_grad()
#     torch.autograd.set_detect_anomaly(True)
#     g_loss.backward(retain_graph=True)
#     avg_loss += g_loss.item()
#     optimizer.step()
#     scheduler.step()

#     if epoch == 500 and i == 0:
#         print('========== Start Refinement ==========')

    
#     if epoch % 100 == 0:
#       if (i) % 1 == 0: 
#         # print ('Iteration {} - loss: {}, # of mesh vertices: {}, # of mesh faces: {}'.format(i, g_loss, mesh_verts.shape[0], mesh_faces.shape[0]))
        
#         # save reconstructed mesh
#         timelapse.add_mesh_batch(
#             iteration=epoch+1,
#             category='final_train_res',
#             vertices_list=[mesh_verts.cpu()],
#             faces_list=[mesh_faces.cpu()]
#         )
      
#       if i == (iterations - 1):
#         print ('Epoch {} - loss: {}, # of mesh vertices: {}, # of mesh faces: {}'.format(epoch, avg_loss/iterations, mesh_verts.shape[0], mesh_faces.shape[0]))

# Ignore

In [None]:
"""
to-dos: volume subdivision

identify T_surf's neighbors: i.e. share a same edge

subdivide T_surf to perform another surface refinement
unsubdivded tet, i.e. not surface's neighbor, is dropped to save memory and computation

"""

#neightbor if share an edge
#identify edge based on surf_faces and adj matrix

"""
convert face_lists = [v1, v2, v3, v4] --> 
[
  [v1, v2, v3, E, o1],
  [v2, v3, v4, E, o2],
  [v1, v2, v4, E, o3],
  [v1, v3, v4, E, o4]
]

by torch combinations, E = tet index, o_i = face_i index
"""
def convert_face_lists(tets):
  face_idx = torch.tensor([1,2,3,4])
  tet_face_list = []
  for idx, tet in enumerate(tets):
    tet_idx = torch.full(4, idx)
    tet_faces = torch.combinations(tet, r=3)
    tet_faces = torch.cat((tet_faces, tet_idx, face_idx), dim=1)
    tet_face_list.append(tet_faces)

  tet_face_list = torch.stack(tet_face_list, dim=0)

  def compare(face_1, face_2):
    o1 = face_1[0]-face_2[0]
    o2 = face_1[1]-face_2[1]
    o3 = face_1[2]-face_2[2]

    if o1 <= 0 or (o1 == 0 and o2 < 0) or (o1 == 0 and o2 == 0 and o3 < 0):
      return -1
    elif o1 == 0 and o2 == 0 and o3 == 0:
      return 0
    else: 
      return 1

  sorted(tet_face_list, cmp=compare)
  return tet_face_list

def get_neighbor(surf_tet_verts, surf_tets, tet_verts, tets):
  tet_face_list = convert_face_lists(tets)
  

surf_idx, surf = extract_tet(tets, pred_sdfs, 0.003)

print(surf_idx)

In [None]:
def get_faces(input):
  c = torch.combinations(torch.arange(input.size(1)), r=3)
  x = input[:,None].expand(-1,len(c),-1).cpu()
  idx = c[None].expand(len(x), -1, -1)
  x = x.gather(dim=2, index=idx)

  return x.view(-1, *x.shape[2:])

In [None]:
# a = get_faces(tets)
# b = get_faces(tets[surf_idx])

# # OPTIMIZE TF OUT OF THIS
# for f in b:
#   mask = a == f


# Inverse Distance Interpolation



```
# This is formatted as code
```

Other idea is to use a Point Cloud Encoder to retrieve a feature vector from the output's activation.

In [None]:
ENCODER_MLP_CONFIG = {
    'input_dim' : 3, # Coordinates of the grid's vertices
    'hidden_dims' : [256, 800, 1600, 1600],
    'output_dim' : 832, # SDF of the vertex input. The other "output" f_v comes from the prior activation layer of dimension 64
}

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

class MLP(torch.nn.Module):
    def __init__(self, config):
        super(MLP, self).__init__()

        self.input_dim = config['input_dim']
        self.hidden_dims  = config['hidden_dims']
        self.output_dim = config['output_dim']

        # Hidden layers
        self.hiddens = nn.ModuleList()
        in_dim = self.input_dim
        for k in range(len(self.hidden_dims)):
            self.hiddens.append(nn.Linear(in_dim, self.hidden_dims[k]))
            in_dim = self.hidden_dims[k]

        # Output layer
        self.output_layer = torch.nn.Linear(self.hidden_dims[-1], self.output_dim)


    def forward(self, x):
        for hidden in self.hiddens :
            x = F.relu(hidden(x))
        output = self.output_layer(x) # No activation (linear) cuz we do regression

        return output, x # Return output + last feature layer vector

In [None]:
# pointnet = MLP(ENCODER_MLP_CONFIG)
# pointnet = pointnet.to(device)

# pc = pc.to(device)

# print(pointnet)
# print('\n\n')
# summary(pointnet, input_size= pc.shape)

# Interpolate feature values on the grid

*We* implement an inverse distance interpolation based on a K-NN algorithm :
Given N known points and their features and a batch of M points with unknown features, the interpolator find the K nearest neigbors in the set N for each point of the batch M. Then it interpolates the features of the batch points M using an inverse distance weighting of the features of the K known neighbour points. 

K can be fine-tuned to balance efficiency and speed.

Implementation can be found here
https://stackoverflow.com/questions/3104781/inverse-distance-weighted-idw-interpolation-with-python

In [None]:
from __future__ import division
import numpy as np
from scipy.spatial import cKDTree as KDTree
    # http://docs.scipy.org/doc/scipy/reference/spatial.html

__date__ = "2010-11-09 Nov"  # weights, doc

#...............................................................................
class Invdisttree:
    """ inverse-distance-weighted interpolation using KDTree:
invdisttree = Invdisttree( X, z )  -- data points, values
interpol = invdisttree( q, nnear=3, eps=0, p=1, weights=None, stat=0 )
    interpolates z from the 3 points nearest each query point q;
    For example, interpol[ a query point q ]
    finds the 3 data points nearest q, at distances d1 d2 d3
    and returns the IDW average of the values z1 z2 z3
        (z1/d1 + z2/d2 + z3/d3)
        / (1/d1 + 1/d2 + 1/d3)
        = .55 z1 + .27 z2 + .18 z3  for distances 1 2 3

    q may be one point, or a batch of points.
    eps: approximate nearest, dist <= (1 + eps) * true nearest
    p: use 1 / distance**p
    weights: optional multipliers for 1 / distance**p, of the same shape as q
    stat: accumulate wsum, wn for average weights

How many nearest neighbors should one take ?
a) start with 8 11 14 .. 28 in 2d 3d 4d .. 10d; see Wendel's formula
b) make 3 runs with nnear= e.g. 6 8 10, and look at the results --
    |interpol 6 - interpol 8| etc., or |f - interpol*| if you have f(q).
    I find that runtimes don't increase much at all with nnear -- ymmv.

p=1, p=2 ?
    p=2 weights nearer points more, farther points less.
    In 2d, the circles around query points have areas ~ distance**2,
    so p=2 is inverse-area weighting. For example,
        (z1/area1 + z2/area2 + z3/area3)
        / (1/area1 + 1/area2 + 1/area3)
        = .74 z1 + .18 z2 + .08 z3  for distances 1 2 3
    Similarly, in 3d, p=3 is inverse-volume weighting.

Scaling:
    if different X coordinates measure different things, Euclidean distance
    can be way off.  For example, if X0 is in the range 0 to 1
    but X1 0 to 1000, the X1 distances will swamp X0;
    rescale the data, i.e. make X0.std() ~= X1.std() .

A nice property of IDW is that it's scale-free around query points:
if I have values z1 z2 z3 from 3 points at distances d1 d2 d3,
the IDW average
    (z1/d1 + z2/d2 + z3/d3)
    / (1/d1 + 1/d2 + 1/d3)
is the same for distances 1 2 3, or 10 20 30 -- only the ratios matter.
In contrast, the commonly-used Gaussian kernel exp( - (distance/h)**2 )
is exceedingly sensitive to distance and to h.

    """
# anykernel( dj / av dj ) is also scale-free
# error analysis, |f(x) - idw(x)| ? todo: regular grid, nnear ndim+1, 2*ndim

    def __init__( self, X, z, leafsize=10, stat=0 ):
        assert len(X) == len(z), "len(X) %d != len(z) %d" % (len(X), len(z))
        self.tree = KDTree( X, leafsize=leafsize )  # build the tree
        self.z = z
        self.stat = stat
        self.wn = 0
        self.wsum = None;

    def __call__( self, q, nnear=6, eps=0, p=1, weights=None ):
            # nnear nearest neighbours of each query point --
        q = np.asarray(q)
        qdim = q.ndim
        if qdim == 1:
            q = np.array([q])
        if self.wsum is None:
            self.wsum = np.zeros(nnear)

        self.distances, self.ix = self.tree.query( q, k=nnear, eps=eps )
        interpol = np.zeros( (len(self.distances),) + np.shape(self.z[0]) )
        jinterpol = 0
        for dist, ix in zip( self.distances, self.ix ):
            if nnear == 1:
                wz = self.z[ix]
            elif dist[0] < 1e-10:
                wz = self.z[ix[0]]
            else:  # weight z s by 1/dist --
                w = 1 / dist**p
                if weights is not None:
                    w *= weights[ix]  # >= 0
                w /= np.sum(w)
                wz = np.dot( w, self.z[ix] )
                if self.stat:
                    self.wn += 1
                    self.wsum += w
            interpol[jinterpol] = wz
            jinterpol += 1
        return interpol if qdim > 1  else interpol[0]

Let's test the interpolation by splitting the point cloud (of size 1000) into a 

1.   List item
2.   List item

known set of size N= 600 and an unknown batch of M= 400 points

In [None]:
INTERP_MLP_CONFIG = {
    'input_dim' : 3, # Coordinates of the grid's vertices
    'hidden_dims' : [256, 800, 1600, 1600],
    'output_dim' : 832, # SDF of the vertex input. The other "output" f_v comes from the prior activation layer of dimension 64
}

pointnet = MLP(INTERP_MLP_CONFIG).to(device)
print(pointnet)
print('\n\n')
summary(pointnet, input_size= gt_verts.shape)

In [None]:
pointnet = MLP(INTERP_MLP_CONFIG).to(device)
pointnet.pre_train_sphere(1000)
sdf_vars = [p for _, p in pointnet.named_parameters()]
sdf_optimizer = torch.optim.Adam(sdf_vars, lr=lr)
sdf_scheduler = torch.optim.lr_scheduler.LambdaLR(sdf_optimizer, lr_lambda=lambda x: max(0.0, 10**(-x*0.0002))) # LR decay over time
# ~12 min. Speed up?
encoder_train(iterations, sdf_model, sdf_optimizer, sdf_scheduler)

pointnet.eval()
f_vol, _ = pointnet(wt_verts)


sample_known = wt_verts.squeeze()[:600] # shape 600 x 3. Set of N= 600 points with known features
sample_unknown = wt_verts.squeeze()[600:] # shape 400x3. Batch of M= 400 points on which we will interpolate the features
f_vol_known = f_vol.squeeze()[:600] # Shape 600 x 832. Set of the feature vectors F_vol of the set N
f_vol_unknown = f_vol.squeeze()[600:] # Shape 400 x 832. Ground truth of F_vol for the batch M unknown points. Used to test the interpolation performance


In [None]:
leafsize = 10 # leaf size of the KDTree. This means the KDTree will store "leafsize" number of neighbour points for each data point
eps = .1  # approximate nearest, dist <= (1 + eps) * true nearest
p = 1  # weights ~ 1 / distance**p
Nnear = 8  # 8 2d, 11 3d => 5 % chance one-sided -- Wendel, mathoverflow.com


invdisttree = Invdisttree( sample_known.squeeze().detach().cpu(), f_vol_known.squeeze().detach().cpu(), leafsize=leafsize, stat=1 )
interpol = invdisttree( sample_unknown.detach().cpu(), nnear=Nnear, eps=eps, p=p ) # return numpy array

# err = np.abs( f_vol_unknown.detach().numpy() - interpol )
# print("average |ground_truth - interpolated|: %.2g" % np.mean(err))

Now we compute the F_vol interpolation on the grid vertices

In [None]:
tets_f_vol = invdisttree( tets_verts.detach().cpu(), nnear=Nnear, eps=eps, p=p ) # ~ 5-10sec
tets_f_vol = torch.tensor(tets_f_vol, device=device, dtype=torch.float)

In [None]:
sdf_input = torch.cat(tensors= (tets_f_vol, tets_verts), dim= 1)

# To Do: Add this to MLP Encoder

In [None]:
SDF_INTERPOLATE_MLP_CONFIG = {
    'input_dim' : 3 + 832, # Coordinates of the grid's vertices + F_vol dimension (concatenation)
    'hidden_dims' : [256, 256, 128, 64],
    'output_dim' : 1, # SDF of the vertex input. The other "output" f_v comes from the prior activation layer of dimension 64
}
sdf_model = MLP(SDF_INTERPOLATE_MLP_CONFIG).to(device)
sdf_model.pre_train_sphere(2000)
sdf_vars = [p for _, p in sdf_model.named_parameters()]
sdf_optimizer = torch.optim.Adam(sdf_vars, lr=lr)
sdf_scheduler = torch.optim.lr_scheduler.LambdaLR(sdf_optimizer, lr_lambda=lambda x: max(0.0, 10**(-x*0.0002))) # LR decay over time
# ~12 min. Speed up?
encoder_train(iterations, sdf_model, sdf_optimizer, sdf_scheduler, tv=sdf_input)

In [None]:
torch.cuda.empty_cache()

# pred_sdfs, f_vs = sdf_model(sdf_input)

# Visualization

In [593]:
#Use pyngrok to access localhost:80 on Colab

!pip install pyngrok --quiet 
from pyngrok import ngrok

# Terminate open tunnels if exist
ngrok.kill()

# Setting the authtoken (optional)
# Get authtoken from https://dashboard.ngrok.com/auth
NGROK_AUTH_TOKEN = "2Hzzzh94FgOXssVkSP5Yffz8uYg_By2RMDZLTPx1aXakhYfH"
ngrok.set_auth_token(NGROK_AUTH_TOKEN)

In [686]:
#generating a public url mapped to localhost 80
public_url = ngrok.connect(port=80, proto="http", options={"bind_tls": True, "local": True})
print("Tracking URL:", public_url)

Tracking URL: NgrokTunnel: "http://7b63-34-141-173-195.ngrok.io" -> "http://localhost:80"


In [687]:
#Start Kaolin Dash3D on localhost:80 
!kaolin-dash3d --logdir=/content/drive/MyDrive/CV_DMTet/Logs --port=80

Dash3D server starting. Go to: http://localhost:80
2022-12-16 00:52:48,131|    INFO|kaolin.visualize.timelapse| No checkpoints found for type voxelgrid: no files matched pattern voxelgrid*.usd in /content/drive/MyDrive/CV_DMTet/Logs
2022-12-16 00:52:53,782|    INFO|kaolin.visualize.timelapse| No checkpoints found for type voxelgrid: no files matched pattern voxelgrid*.usd in /content/drive/MyDrive/CV_DMTet/Logs
2022-12-16 00:52:53,809|    INFO| tornado.access| 200 GET / (127.0.0.1) 213.34ms
2022-12-16 00:52:54,097|    INFO| tornado.access| 200 GET /static/thirdparty.css (127.0.0.1) 2.78ms
2022-12-16 00:52:54,153|    INFO| tornado.access| 200 GET /static/thirdparty.js (127.0.0.1) 4.26ms
2022-12-16 00:52:54,197|    INFO| tornado.access| 200 GET /static/style.css (127.0.0.1) 4.20ms
2022-12-16 00:52:54,200|    INFO| tornado.access| 200 GET /static/core-min.js (127.0.0.1) 2.72ms
2022-12-16 00:52:54,857|    INFO| tornado.access| 200 GET /static/green_plastic.frag (127.0.0.1) 2.53ms
2022-12-1

# Evaluation

In [636]:
baseline = Decoder(multires=2).to(device)
baseline.pre_train_sphere(1000)
baseline_vars = [p for _, p in baseline.named_parameters()]
baseline_optimizer = torch.optim.Adam(baseline_vars, lr=lr)
baseline_scheduler = torch.optim.lr_scheduler.LambdaLR(baseline_optimizer, lr_lambda=lambda x: max(0.0, 10**(-x*0.0002))) # LR decay over time
sdf_train(iterations, baseline, baseline_optimizer, baseline_scheduler)

Initialize SDF to sphere


100%|██████████| 1000/1000 [00:17<00:00, 58.68it/s]


Pre-trained MLP 0.00047968627768568695
Iteration 0 - loss: 1.1691309213638306
Iteration 100 - loss: 0.9695544838905334
Iteration 200 - loss: 0.09632495790719986
Iteration 300 - loss: 0.002511597704142332
Iteration 400 - loss: 0.004841003566980362
Iteration 500 - loss: 0.0031850431114435196
Iteration 600 - loss: 0.004886094946414232
Iteration 700 - loss: 0.006782314274460077
Iteration 800 - loss: 0.0021261628717184067
Iteration 900 - loss: 0.004020349122583866
Iteration 1000 - loss: 0.0036146368365734816
Iteration 1100 - loss: 0.0021590515971183777
Iteration 1200 - loss: 0.00301489420235157
Iteration 1300 - loss: 0.0010439787292852998
Iteration 1400 - loss: 0.00287272478453815
Iteration 1500 - loss: 0.0029820026829838753
Iteration 1600 - loss: 0.002300722524523735
Iteration 1700 - loss: 0.003140554064884782
Iteration 1800 - loss: 0.002698600059375167
Iteration 1900 - loss: 0.0022180448286235332
Iteration 2000 - loss: 0.0020990099292248487
Iteration 2100 - loss: 0.0016078735934570432
Ite

In [637]:
def test(model, points, gt_verts, df, label, modelname):
  # Set to training mode
  model.eval()

  with torch.no_grad():
      F_vol = F.interpolate(points[None, None, None,:, :], size=(1,tets_verts.shape[0], tets_verts.shape[1]), mode='trilinear', align_corners=False)
      F_vol = torch.cat((tets_verts, F_vol.squeeze(0).squeeze(0).squeeze(0).to(device)), dim=1)
      
      pred, _ = model(F_vol)

      pred_sdfs, deform = pred[:,0], pred[:,1:]
      verts_deformed = tets_verts + torch.tanh(deform) / grid_res # constraint deformation to avoid flipping tets
      pred_mesh_verts, pred_mesh_faces = marching_tetrahedra(verts_deformed.unsqueeze(0), tets, pred_sdfs.unsqueeze(0))
      pred_mesh_verts, pred_mesh_faces = pred_mesh_verts[0], pred_mesh_faces[0]

      pred_points = kaolin.ops.mesh.sample_points(pred_mesh_verts.unsqueeze(0), 
                                                pred_mesh_faces, 5000)[0][0]

      chamferL1 = 100*pytorch3d.loss.chamfer_distance(pred_points.unsqueeze(0), gt_verts.unsqueeze(0), norm=1)[0] 
      chamferL2 = 100*kaolin.metrics.pointcloud.chamfer_distance(pred_points.unsqueeze(0), gt_verts.unsqueeze(0), squared=False)[0] 
      fscore = 100*kaolin.metrics.pointcloud.f_score(gt_verts.unsqueeze(0), pred_points.unsqueeze(0))[0]


      print (f'chamfer loss (L1) for {label}: {chamferL1}')
      res = [label, 
             modelname,
             chamferL1.cpu().numpy(), 
             chamferL2.cpu().numpy(), 
             fscore.cpu().numpy()]
      df = df.append({a:v for (a,v) in zip(df.columns, res)}, ignore_index=True)
      # save reconstructed mesh
      timelapse.add_mesh_batch(
          category="results_" +label,
          vertices_list=[pred_mesh_verts.cpu()],
          faces_list=[pred_mesh_faces.cpu()]
      )
  return df

In [638]:
next_shapenet_idx = group[-1]["idx"]
print(next_shapenet_idx)

17


In [709]:
def test_GCN(gcn_model, sdf_model, wt_grid, gt_verts, df, label, modelname):
    # Set to training mode
    gcn_model.eval()
    sdf_model.eval()

    F_vol = F.interpolate(wt_grid.unsqueeze(0), size=(1,tets_verts.shape[0], tets_verts.shape[1]), mode='trilinear', align_corners=False)

    F_vol = torch.cat((tets_verts, F_vol.squeeze(0).squeeze(0).squeeze(0).to(device)), dim=1)

    pred_sdfs, f_vs = sdf_model(F_vol)

    """
    Surface Refinement
    """

    surf_tets_verts_idx, surf_tets_faces = extract_tet(tets, pred_sdfs, 0.008) #if not working modify the 0.008 here; this is the threshold for surface sdf value
    surf_tets_verts = torch.clone(tets_verts[surf_tets_verts_idx])

    surf_tets_verts_features = torch.clone(f_vs[surf_tets_verts_idx])
    surf_sdfs = pred_sdfs[surf_tets_verts_idx]
    surf_tets_edges = torch.clone(get_edges(surf_tets_faces).to(device))
    surf_verts_f = torch.cat((surf_tets_verts, surf_sdfs, surf_tets_verts_features), dim=1)

    sdf, deform, fv = gcn_model(surf_verts_f, surf_tets_edges)
    
    """

    Update surface position, sdf, and f_s

    """

    #updated sdf

    update_sdfs = pred_sdfs.clone()
    update_sdfs[surf_tets_verts_idx] += sdf

    #update vertices positions

    update_tets_verts = tets_verts.clone()
    update_tets_verts[surf_tets_verts_idx] += deform / grid_res

    #update vertices features

    update_tets_f = f_vs.clone()
    update_tets_f[surf_tets_verts_idx] += fv

    mesh_verts, mesh_faces = kaolin.ops.conversions.marching_tetrahedra(update_tets_verts.unsqueeze(0), tets, update_sdfs.squeeze(1).unsqueeze(0))
    mesh_verts, mesh_faces = mesh_verts[0], mesh_faces[0]

    pred_points = kaolin.ops.mesh.sample_points(mesh_verts.unsqueeze(0), 
                                                mesh_faces, 5000)[0][0]

    chamferL1 = 100*pytorch3d.loss.chamfer_distance(pred_points.unsqueeze(0).detach(), gt_verts.unsqueeze(0), norm=1)[0] 
    chamferL2 = 100*kaolin.metrics.pointcloud.chamfer_distance(pred_points.unsqueeze(0).detach(), gt_verts.unsqueeze(0), squared=False)[0] 
    fscore = 100*kaolin.metrics.pointcloud.f_score(gt_verts.unsqueeze(0), pred_points.unsqueeze(0).detach())[0]


    print (f'chamfer loss (L1) for {label}: {chamferL1}')
    res = [label, 
            modelname,
            chamferL1.cpu().numpy(), 
            chamferL2.cpu().numpy(), 
            fscore.cpu().numpy()]
    df = df.append({a:v for (a,v) in zip(df.columns, res)}, ignore_index=True)
    # save reconstructed mesh
    timelapse.add_mesh_batch(
        category="results_" +label,
        vertices_list=[mesh_verts.cpu()],
        faces_list=[mesh_faces.cpu()]
    )
    return df

In [710]:
import pandas as pd

def  create_df(num_test_models, modelname ):
  assert modelname in ("Baseline", "Ours")
  next_shapenet_idx = group[-1]["idx"]
  print(next_shapenet_idx)
  df = pd.DataFrame(columns = ['Label', 'ModelName', 'L1 Chamfer', 'L2 Chamfer', 'F Score'])
  
  idx = next_shapenet_idx
  label=f"bench_{modelname}_{idx}"

  idx, gt_verts, gt_faces, wt_grid, wt_verts, wt_faces, points = get_next_shapenet(idx)
  if modelname == "Baseline":
    df = test(baseline, points, gt_verts, df, label, modelname)
  else:
    df = test_GCN(refine_model,sdf_model, wt_grid, gt_verts, df, label, modelname=modelname)
  
  for i in range(num_test_models - 1):
    idx, gt_verts, gt_faces, wt_grid, wt_verts, wt_faces, points = get_next_shapenet(idx)
    if modelname == "Baseline":
      df = test(baseline, points, gt_verts, df,  label, modelname)
    else:
      df = test_GCN(refine_model,sdf_model, wt_grid, gt_verts, df, label, modelname=modelname)
  
  return df

In [711]:
create_df(4, "Baseline")

17
chamfer loss (L1) for bench_Baseline_17: 20.075788497924805
chamfer loss (L1) for bench_Baseline_17: 19.034822463989258
chamfer loss (L1) for bench_Baseline_17: 8.30785846710205
chamfer loss (L1) for bench_Baseline_17: 19.667226791381836


Unnamed: 0,Label,ModelName,L1 Chamfer,L2 Chamfer,F Score
0,bench_Baseline_17,Baseline,20.075788,15.6275425,3.3848615
1,bench_Baseline_17,Baseline,19.034822,13.452649,0.15147115
2,bench_Baseline_17,Baseline,8.307858,5.875101,6.636428
3,bench_Baseline_17,Baseline,19.667227,14.449131,0.92735654


In [712]:
create_df(4, "Ours")

17
chamfer loss (L1) for bench_Ours_17: 24.275863647460938
chamfer loss (L1) for bench_Ours_17: 22.61875343322754
chamfer loss (L1) for bench_Ours_17: 14.254079818725586
chamfer loss (L1) for bench_Ours_17: 19.166410446166992


Unnamed: 0,Label,ModelName,L1 Chamfer,L2 Chamfer,F Score
0,bench_Ours_17,Ours,24.275864,18.709873,2.6170237
1,bench_Ours_17,Ours,22.618753,15.847124,0.17573375
2,bench_Ours_17,Ours,14.25408,10.072594,2.519173
3,bench_Ours_17,Ours,19.16641,14.010626,1.1579027
