<a href="https://colab.research.google.com/github/Aydin-ab/CV_DMTet/blob/main/CV_DMTet_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Installing Dependencies: Kaolin and PyTorch3D

In [None]:
!pip install torch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113 --quiet

[K     |██████████████▌                 | 834.1 MB 1.2 MB/s eta 0:13:31tcmalloc: large alloc 1147494400 bytes == 0x38e02000 @  0x7f1128339615 0x5d6f4c 0x51edd1 0x51ef5b 0x4f750a 0x4997a2 0x4fd8b5 0x4997c7 0x4fd8b5 0x49abe4 0x4f5fe9 0x55e146 0x4f5fe9 0x55e146 0x4f5fe9 0x55e146 0x5d8868 0x5da092 0x587116 0x5d8d8c 0x55dc1e 0x55cd91 0x5d8941 0x49abe4 0x55cd91 0x5d8941 0x4990ca 0x5d8868 0x4997a2 0x4fd8b5 0x49abe4
[K     |██████████████████▍             | 1055.7 MB 1.2 MB/s eta 0:10:51tcmalloc: large alloc 1434370048 bytes == 0x7d458000 @  0x7f1128339615 0x5d6f4c 0x51edd1 0x51ef5b 0x4f750a 0x4997a2 0x4fd8b5 0x4997c7 0x4fd8b5 0x49abe4 0x4f5fe9 0x55e146 0x4f5fe9 0x55e146 0x4f5fe9 0x55e146 0x5d8868 0x5da092 0x587116 0x5d8d8c 0x55dc1e 0x55cd91 0x5d8941 0x49abe4 0x55cd91 0x5d8941 0x4990ca 0x5d8868 0x4997a2 0x4fd8b5 0x49abe4
[K     |███████████████████████▎        | 1336.2 MB 1.2 MB/s eta 0:06:49tcmalloc: large alloc 1792966656 bytes == 0x228a000 @  0x7f1128339615 0x5d6f4c 0x51edd1 0x51ef5b 0x4

In [None]:
from google.colab import drive
drive.mount('/content/drive')
from pathlib import Path
# Project root on the mounted Drive. The folder name must NOT be quoted: inside a
# Python string, quotes around 'MyDrive' become literal characters of the path.
# (It presumably only worked before because `%cd $INSTALL_PATH` and the `!` shell
# lines strip the quotes again; any pure-Python use of the Path would fail.)
INSTALL_PATH = Path("/content/drive/MyDrive/CV_DMTet/")
%cd  $INSTALL_PATH

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/CV_DMTet


In [None]:
# reinstall cython, install usd-core (for 3D rendering), and clone into kaolin repo
!pip uninstall Cython --yes
!pip install Cython==0.29.20  --quiet
!pip install usd-core --quiet

In [None]:
# installing kaolin and check version
# The two env vars are read by kaolin's setup.py (presumably: skip its PyTorch
# version check / build the experimental ops -- confirm in kaolin's install docs).
%env IGNORE_TORCH_VER=1
%env KAOLIN_INSTALL_EXPERIMENTAL=1
KAOLIN_PATH = INSTALL_PATH / "kaolin"
%cd $INSTALL_PATH
# Clone (with submodules) only when the checkout is not already on Drive.
!if [ ! -d $KAOLIN_PATH ]; then git clone --recursive https://github.com/NVIDIAGameWorks/kaolin; fi;
%cd $KAOLIN_PATH
# SETUP_CHECK is only used by the disabled guard below; as written, the build
# runs on every execution of this cell.
SETUP_CHECK = KAOLIN_PATH / "kaolin" / "version.py"
# !echo Checking if $SETUP_CHECK exists
# !if [ ! -f $SETUP_CHECK ]; then python setup.py develop; fi;
!python setup.py develop
# Sanity check: the installed package imports and reports its version.
!python -c "import kaolin; print(kaolin.__version__)"

# Import packages

In [None]:
import numpy as np
import torch
import kaolin
import sys
import os

# Make sure PyTorch3D is importable: install it only when the import fails.
need_pytorch3d=False
try:
    import pytorch3d
except ModuleNotFoundError:
    need_pytorch3d=True
if need_pytorch3d:
    if torch.__version__.startswith("1.12.") and sys.platform.startswith("linux"):
        # We try to install PyTorch3D via a released wheel.
        # The wheel index folder is keyed by Python minor version, CUDA version and
        # torch version, e.g. "py3<minor>_cu113_pyt1121" for the torch pinned above.
        # NOTE(review): torch.version.cuda is None on a CPU-only torch build, in which
        # case the .replace() below raises instead of falling back to a source build.
        pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
        version_str="".join([
            f"py3{sys.version_info.minor}_cu",
            torch.version.cuda.replace(".",""),
            f"_pyt{pyt_version_str}"
        ])
        !pip install fvcore iopath
        !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
    else:
        # We try to install PyTorch3D from source.
        # CUB is downloaded and exposed through CUB_HOME for the source build
        # (presumably required to compile PyTorch3D's CUDA code).
        !curl -LO https://github.com/NVIDIA/cub/archive/1.10.0.tar.gz
        !tar xzf 1.10.0.tar.gz
        os.environ["CUB_HOME"] = os.getcwd() + "/cub-1.10.0"
        !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'
    import pytorch3d

from kaolin.ops.conversions import (
    trianglemeshes_to_voxelgrids,
    marching_tetrahedra,
    voxelgrids_to_cubic_meshes,
    voxelgrids_to_trianglemeshes,
)

from kaolin.ops.mesh import (
    index_vertices_by_faces
)

from kaolin.io.shapenet import (
    ShapeNetV2
)

from kaolin.metrics.trianglemesh import (
    point_to_mesh_distance,

)

from pytorch3d.datasets import (
    ShapeNetCore,
    collate_batched_meshes
)


from torch.utils.data import DataLoader

# add path for demo utils functions 
# Make the working directory, the project folder and its pvcnn checkout importable
# (appended in this order, after the default search path).
for _extra_path in (os.path.abspath(''),
                    '/content/drive/MyDrive/CV_DMTet/',
                    '/content/drive/MyDrive/CV_DMTet/pvcnn'):
    sys.path.append(_extra_path)

# One fixed seed for torch's CPU and CUDA generators.
_SEED = 3407
torch.manual_seed(_SEED)
torch.cuda.manual_seed(_SEED)

# Import Dataset: Subset of ShapeNetV2

In [None]:
# Use the first GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    torch.cuda.set_device(device)
device

device(type='cuda', index=0)

In [None]:
# path to the output logs (readable with the training visualizer in the omniverse app)
# This is also the --logdir passed to kaolin-dash3d in the Visualization section;
# keep the two in sync.
logs_path = '/content/drive/MyDrive/CV_DMTet/Logs'

# We initialize the timelapse that will store USD for the visualization apps
# (later cells and test_train write meshes to it with add_mesh_batch).
timelapse = kaolin.visualize.Timelapse(logs_path)

In [None]:
# import by pytorch3d
# (the pytorch3d loader below is commented out; the kaolin loader in the next cell is used)

SHAPENET_PATH = "/content/drive/MyDrive/CV_DMTet/Core"
#SHAPENET_PATH = "/content/drive/MyDrive/FALL 2022/Computer Vision/Project/Core"
# ShapeNetCore synset (category) ids to load. Presumably e.g. 02876657 = bottle and
# 02933112 = cabinet -- verify against ShapeNet's taxonomy.json.
SYNSETS_IDS = ['02747177', '02773838', '02801938', '02808440', '02818832', '02828884', '02843684', '02871439', '02876657', '02880940', '02924116', '02933112']
# shapenet_dataset = ShapeNetCore(SHAPENET_PATH, synsets=SYNSETS_IDS, version=2)
# shapenet_loader = DataLoader(shapenet_dataset, batch_size=3, collate_fn=collate_batched_meshes)

In [None]:
# import by kaolin
# Train / test splits over the same categories. With output_dict=True each item is
# a dict; its 'mesh' entry is indexed below as [0] (vertices) and [1] (faces).

shapenet_train = ShapeNetV2(SHAPENET_PATH, categories=SYNSETS_IDS, output_dict=True)
shapenet_test = ShapeNetV2(SHAPENET_PATH, categories=SYNSETS_IDS, output_dict=True, train=False)

# Model visualization

In [None]:
sample_model = shapenet_train[560] # change the index here for different models
sample_verts, sample_faces = sample_model['mesh'][0], sample_model['mesh'][1]

# Center the mesh on its bounding-box center and scale its longest side to 1.
bbox_max, bbox_min = sample_verts.max(0)[0], sample_verts.min(0)[0]
center = (bbox_max + bbox_min) / 2
max_l = (bbox_max - bbox_min).max()
sample_verts = (sample_verts - center) / max_l

timelapse.add_mesh_batch(
    category='gt',
    vertices_list=[sample_verts.cpu()],
    faces_list=[sample_faces.cpu()],
)

# Convert model to watertight meshes

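We voxelize the mesh at resolution 64 and convert the voxel grid back into a mesh of voxel-cube faces. This gives a watertight version of the shape, whose signed distance is used as the SDF target during pretraining.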

In [None]:
# Voxelize the normalized sample mesh at resolution 64 (a batch of one mesh).
wt_grid = kaolin.ops.conversions.trianglemeshes_to_voxelgrids(
    vertices=sample_verts.unsqueeze(0).to(device),
    faces=sample_faces.to(device),
    resolution=64
)

In [None]:
# Mesh made of the voxels' cube faces (the watertight proxy); keep the single batch element.
wt_verts, wt_faces = kaolin.ops.conversions.voxelgrids_to_cubic_meshes(wt_grid)
wt_verts, wt_faces = wt_verts[0], wt_faces[0]

In [None]:
# Re-center the watertight mesh and scale its longest bounding-box side to 0.8.
wt_hi, wt_lo = wt_verts.max(0)[0], wt_verts.min(0)[0]
center = (wt_hi + wt_lo) / 2
max_l = (wt_hi - wt_lo).max()
wt_verts = (wt_verts - center) / max_l * 0.8

timelapse.add_mesh_batch(
    category='watertight_test',
    vertices_list=[wt_verts.cpu()],
    faces_list=[wt_faces.cpu()],
)

# Load tetrahedral grid

DMTet starts from a uniform tetrahedral grid of predefined resolution and uses a network to predict the SDF value as well as a deformation vector at each grid vertex.

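Here we load a pre-generated tetrahedral grid (made with Quartet at resolution 128). It has roughly the same number of vertices as a voxel grid of resolution 65 (277,410 vs. 65³ = 274,625). An MLP predicts an SDF value and a feature vector at every grid vertex. A graph network then refines the SDF and predicts vertex deformations near the surface. Positional encoding is implemented below but currently disabled. The SDF is not initialized to a sphere; it is pretrained against the watertight target shape.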

In [None]:
# Uniform tetrahedral grid (Quartet, resolution 128) shipped with kaolin's examples.
TET_SAMPLES_DIR = '/content/drive/MyDrive/CV_DMTet/kaolin/examples/samples'
tets_verts = torch.tensor(np.load(f'{TET_SAMPLES_DIR}/128_verts.npz')['data'], dtype=torch.float, device=device)
# Tet indices are stored as four files, one per corner. Stack them into a single
# (4, T) ndarray before converting: building a tensor from a Python list of ndarrays
# is very slow and makes PyTorch emit a warning. Then transpose to (T, 4).
tets = torch.tensor(
    np.stack([np.load(f'{TET_SAMPLES_DIR}/128_tets_{i}.npz')['data'] for i in range(4)]),
    dtype=torch.long, device=device,
).permute(1, 0)
print(tets_verts, tets)

tensor([[ 0.5000,  0.5000,  0.4844],
        [ 0.4844,  0.5000,  0.4922],
        [ 0.4922,  0.4844,  0.4844],
        ...,
        [-0.1719, -0.5000,  0.4766],
        [-0.1562, -0.5000,  0.4688],
        [-0.1562, -0.4922,  0.4688]], device='cuda:0') tensor([[     0,      1,      2,      3],
        [     2,      3,      1,      4],
        [     5,      3,      0,      2],
        ...,
        [277409, 272920, 272914, 272919],
        [272919, 277409, 272920, 274866],
        [277409, 277400, 272920, 274866]], device='cuda:0')


  tets = torch.tensor(([np.load('/content/drive/MyDrive/CV_DMTet/kaolin/examples/samples/128_tets_{}.npz'.format(i))['data'] for i in range(4)]), dtype=torch.long, device=device).permute(1,0)


# SDF Model & GCN Model

We follow the paper's recommendation and use a four-layer
MLP with hidden dimensions 256, 256, 128 and 64, respectively.

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

class MLP(torch.nn.Module):
    """Per-vertex SDF network: a stack of ReLU layers with a linear regression head.

    ``config`` keys:
        input_dim   -- width of each input row.
        hidden_dims -- widths of the hidden ReLU layers, in order.
        output_dim  -- width of the (linear) prediction head.
        multires    -- stored for positional encoding; the encoder is currently
                       disabled (``embed_fn`` stays None), so it has no effect.

    ``forward(x)`` returns ``(prediction, features)``, where ``features`` is the
    activation of the last hidden layer (the per-vertex feature vector f_v).
    """

    def __init__(self, config):
        super().__init__()
        self.input_dim = config['input_dim']
        self.hidden_dims = config['hidden_dims']
        self.output_dim = config['output_dim']
        self.multires = config['multires']

        # Positional encoding is switched off. Re-enabling it means setting
        # embed_fn from get_embedder(self.multires) and using its out_dim as input_dim.
        self.embed_fn = None

        # Hidden layers: input_dim -> hidden_dims[0] -> ... -> hidden_dims[-1].
        widths = [self.input_dim, *self.hidden_dims]
        self.hiddens = nn.ModuleList(
            nn.Linear(d_in, d_out) for d_in, d_out in zip(widths[:-1], widths[1:])
        )
        # Linear head: the SDF is a regression target, so no activation.
        self.output_layer = torch.nn.Linear(self.hidden_dims[-1], self.output_dim)

    def forward(self, x):
        h = x if self.embed_fn is None else self.embed_fn(x)
        for layer in self.hiddens:
            h = F.relu(layer(h))
        return self.output_layer(h), h


# Positional Encoding from https://github.com/yenchenlin/nerf-pytorch/blob/1f064835d2cca26e4df2d7d130daa39a8cee1795/run_nerf_helpers.py
class Embedder:
    """NeRF-style positional encoding ``x -> [x, p(f_1 x), ..., p(f_K x)]``.

    Required keyword arguments:
        include_input -- if true, the raw input is the first component.
        input_dims    -- width of the last axis of the inputs.
        max_freq_log2 -- log2 of the highest frequency band.
        num_freqs     -- number of frequency bands.
        log_sampling  -- bands are 2**linspace(0, max_freq_log2) when true,
                         linspace(1, 2**max_freq_log2) otherwise.
        periodic_fns  -- functions applied to ``x * freq`` for every band.

    The constructor sets ``embed_fns`` (one callable per component, in output
    order) and ``out_dim`` (total width of the encoding).
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        cfg = self.kwargs
        width = cfg['input_dims']
        fns = [lambda x: x] if cfg['include_input'] else []

        top = cfg['max_freq_log2']
        n_bands = cfg['num_freqs']
        if cfg['log_sampling']:
            bands = 2. ** torch.linspace(0., top, steps=n_bands)
        else:
            bands = torch.linspace(2. ** 0., 2. ** top, steps=n_bands)

        # Bind band and function through default arguments; a bare closure
        # would only see the final loop values.
        for freq in bands:
            fns.extend(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)
                       for p_fn in cfg['periodic_fns'])

        self.embed_fns = fns
        # Every component keeps the input width.
        self.out_dim = width * len(fns)

    def embed(self, inputs):
        """Concatenate all encoding components along the last axis."""
        return torch.cat([component(inputs) for component in self.embed_fns], dim=-1)


def get_embedder(multires):
    """Return ``(embed, out_dim)`` for 3-D inputs encoded with ``multires`` sin/cos bands.

    Bands are 1, 2, 4, ..., 2**(multires - 1) and the raw input is included, so
    ``out_dim == 3 + 3 * 2 * multires``.
    """
    encoder = Embedder(
        include_input=True,
        input_dims=3,
        max_freq_log2=multires - 1,
        num_freqs=multires,
        log_sampling=True,
        periodic_fns=[torch.sin, torch.cos],
    )
    return (lambda x, eo=encoder: eo.embed(x)), encoder.out_dim

In [None]:
"""
Graph-res net:
Identify surface tetrahedral, build adj matrix, 
"""

from tqdm import tqdm

from torch import nn
from torch.nn import functional as F
from pytorch3d.ops import GraphConv

# A single graph residual block: GraphConv(in -> hidden) -> GraphConv(hidden -> in) + skip.
class GResBlock(nn.Module):
    """Two-layer graph-convolution residual block.

    Args:
        in_dim: feature width of the block input and output (required for the skip).
        hidden_dim: width of the intermediate GraphConv.
        activation: if truthy, ReLU is applied after the first conv and after the
            residual sum; otherwise the block is linear (but still residual).

    ``forward`` takes and returns ``[features, edges]`` so blocks can be chained in
    ``nn.Sequential``. Features are assumed to be (V, in_dim); ``edges`` is the
    GraphConv edge list.
    """
    def __init__(self, in_dim, hidden_dim, activation=None):
        super(GResBlock, self).__init__()

        self.conv1 = GraphConv(in_dim, hidden_dim)
        self.conv2 = GraphConv(hidden_dim, in_dim)
        self.activation = F.relu if activation else None

    def forward(self, inputs):
        x_in, adj = inputs[0], inputs[1]
        x = self.conv1(x_in, adj)
        if self.activation:
            x = self.activation(x)
        x = self.conv2(x, adj)
        # The skip connection must not depend on the activation flag. Previously the
        # residual sum was only taken when an activation was configured, so a block
        # built with activation=None silently lost its skip connection.
        x = x_in + x
        if self.activation:
            x = self.activation(x)

        return [x, adj]

class GBottleneck(nn.Module):
    """Graph bottleneck: GraphConv(in_dim -> hidden_dim[0]) followed by ``block_num``
    GResBlocks of width hidden_dim[0] (inner width hidden_dim[1]).

    With a truthy ``activation``, ReLU follows the entry conv and the block stack.
    ``out_dim`` is accepted for interface compatibility but unused: the output
    width is hidden_dim[0].
    """
    def __init__(self, block_num, in_dim, hidden_dim, out_dim, activation=None):
        super().__init__()
        # The residual stack is built before conv1 to keep the parameter
        # registration (and therefore initialization) order.
        self.blocks = nn.Sequential(*[
            GResBlock(in_dim=hidden_dim[0], hidden_dim=hidden_dim[1], activation=activation)
            for _ in range(block_num)
        ])
        self.conv1 = GraphConv(in_dim, hidden_dim[0])
        self.activation = F.relu if activation else None

    def forward(self, inputs, adj):
        act = self.activation
        x = self.conv1(inputs, adj)
        x = act(x) if act else x
        x, _ = self.blocks([x, adj])
        return act(x) if act else x

class GCN_Res(nn.Module):
    """Graph refinement network for the surface tetrahedra.

    A GBottleneck (two residual graph blocks) encodes per-vertex features. Three
    bias-free linear heads then predict an SDF correction, a tanh-bounded
    deformation and a feature update.

    ``config`` keys: in_dim, hidden_dim ([bottleneck width, block inner width]),
    out_dim (head input width; must match hidden_dim[0]), activation,
    mlp_hdim (two head hidden widths) and mlp_odim (stored, unused).

    NOTE(review): the heads have no nonlinearity between their Linear layers, so
    each one collapses to a single linear map.
    """

    def __init__(self, config):
        super().__init__()
        self.in_dim = config['in_dim']
        self.hidden_dim = config['hidden_dim']
        self.out_dim = config['out_dim']
        self.activation = config['activation']
        self.mlp_hdim = config['mlp_hdim']
        self.mlp_odim = config['mlp_odim']

        self.gcn_res = nn.ModuleList([
            GBottleneck(2, self.in_dim, self.hidden_dim, self.out_dim, self.activation)
        ])
        # Heads are created in the order sdf, deform, feature (keeps init order).
        self.sdf_mlp = self._make_head(1)
        self.deform_mlp = self._make_head(3)
        self.feature_mlp = self._make_head(self.mlp_hdim[1])

    def _make_head(self, n_out):
        """Bias-free linear stack out_dim -> mlp_hdim[0] -> mlp_hdim[1] -> n_out."""
        return nn.Sequential(
            nn.Linear(self.out_dim, self.mlp_hdim[0], bias=False),
            nn.Linear(self.mlp_hdim[0], self.mlp_hdim[1], bias=False),
            nn.Linear(self.mlp_hdim[1], n_out, bias=False),
        )

    def forward(self, inputs, adj):
        """Return ``(sdf, deform, feature)`` for vertex features ``inputs`` and edge list ``adj``."""
        encoded = self.gcn_res[0](inputs, adj)
        sdf = self.sdf_mlp(encoded)
        # tanh keeps each deformation component in [-1, 1].
        deform = torch.tanh(self.deform_mlp(encoded))
        feature = self.feature_mlp(encoded)
        return sdf, deform, feature

# Surface refinement utils

In [None]:
# Get edges lists of shape [E, 2] from face list of shape [V, 4]

def get_edges(input):
  c = torch.combinations(torch.arange(input.size(1)), r=2)
  x = input[:,None].expand(-1,len(c),-1).cpu()
  idx = c[None].expand(len(x), -1, -1)
  x = x.gather(dim=2, index=idx)

  return x.view(-1, *x.shape[2:])

# Extract tets under certain sdf restrictions:
# if thresh = 0, return all surface tetrahedra (those the zero level set crosses)
# if thresh > 0, return all tetrahedra whose vertices' sdfs are all in the range [-thresh, thresh]

def extract_tet(tets, sdf, thresh, non_surf=False):
  """Select tetrahedra by their vertex SDF values and re-index their vertices.

  Args:
    tets: (T, 4) long tensor of grid vertex indices.
    sdf: per-vertex SDF of shape (V, 1) or (V,).
    thresh: 0 selects tets with a sign change (between 1 and 3 vertices with
      sdf > 0); a positive value selects tets whose four SDFs all lie in
      [-thresh, thresh].
    non_surf: also return the complementary (unselected) tets.

  Returns:
    ``(surf_tets_idx, surf_tets)``: the sorted unique grid-vertex indices used by
    the selected tets, and those tets re-indexed into that list (each entry of
    ``surf_tets`` indexes ``surf_tets_idx``). With ``non_surf=True`` the same pair
    for the unselected tets follows.

  Raises:
    AssertionError: if ``thresh`` is negative.
  """
  assert thresh >= 0

  # Accept (V, 1) or (V,) SDFs; gather per-tet vertex values -> (T, 4).
  sdf_flat = sdf.squeeze(-1) if sdf.dim() == 2 and sdf.shape[-1] == 1 else sdf
  tet_sdf = sdf_flat[tets]

  if thresh == 0:
    n_pos = (tet_sdf > 0).long().sum(1)
    surf_mask = (n_pos > 0) & (n_pos < 4)
  else:
    n_in = ((tet_sdf >= -thresh) & (tet_sdf <= thresh)).long().sum(1)
    surf_mask = n_in == 4

  # unique(return_inverse=True) yields the used vertex ids and the tets remapped into them.
  surf_tets_idx, surf_tets = tets[surf_mask].unique(return_inverse=True)

  if non_surf:
    non_surf_tets_idx, non_surf_tets = tets[~surf_mask].unique(return_inverse=True)
    return surf_tets_idx, surf_tets, non_surf_tets_idx, non_surf_tets

  return surf_tets_idx, surf_tets

# Get output from the initial SDF MLP, without tracking gradients

def get_pred_sdfs(model, tets_verts, gt_verts):
  """Run ``model`` once on the tet-grid vertices under ``torch.no_grad()``.

  The model input is the grid vertex coordinates concatenated with ``gt_verts``
  resampled by trilinear ``F.interpolate`` to one row per grid vertex.

  Args:
    model: module returning ``(pred_sdfs, f_vs)``; presumably the SDF ``MLP``.
    tets_verts: (N, C) grid vertex positions (C = 3 for the tet grid).
    gt_verts: (V, C) float tensor, resampled to (N, C).

  Returns:
    ``(pred_sdfs, f_vs)`` as produced by ``model`` on the (N, 2*C) input.

  The model's train/eval mode is restored on exit (it used to be left in eval
  mode as a hidden side effect on the caller's model).
  """
  was_training = model.training
  model.eval()
  try:
    with torch.no_grad():
      F_vol = F.interpolate(gt_verts[None, None, None, :, :], size=(1, tets_verts.shape[0], tets_verts.shape[1]), mode='trilinear', align_corners=False)
      # torch.cat needs both operands on one device, so follow tets_verts
      # instead of relying on a global `device`.
      F_vol = torch.cat((tets_verts, F_vol.squeeze(0).squeeze(0).squeeze(0).to(tets_verts.device)), dim=1)
      pred_sdfs, f_vs = model(F_vol)
  finally:
    model.train(was_training)

  return pred_sdfs, f_vs

# Get ground truth sdf from input verts and faces
# input shapes: gt_verts [batch_size, num_vertices, 3], gt_faces [num_faces, 3] (triangles),
#               points [batch_size, num_points, 3]
# output shape: [num_points, 1] (batch axis squeezed, so batch_size is expected to be 1)

def get_gt_sdfs(gt_verts, gt_faces, points, f=None):
  """Signed distance from ``points`` to the mesh (``gt_verts``, ``gt_faces``).

  Args:
    gt_verts: (1, V, 3) mesh vertices.
    gt_faces: (F, 3) triangle indices.
    points: (1, N, 3) query points.
    f: optional precomputed ``index_vertices_by_faces(gt_verts, gt_faces)``;
      computed here when None.

  Returns:
    (N, 1) tensor; points that ``check_sign`` reports as inside get a negative sign.

  NOTE(review): kaolin documents ``point_to_mesh_distance`` as returning *squared*
  Euclidean distances; if so, the magnitude here is a signed squared distance,
  not a true SDF -- confirm before relying on it as one.
  """
  if f is None:
    f = index_vertices_by_faces(gt_verts, gt_faces)
  d, _, _ = point_to_mesh_distance(points, f)

  # check_sign marks points inside the (watertight) mesh; flip their sign.
  inside = kaolin.ops.mesh.check_sign(gt_verts, gt_faces, points)
  d[inside.bool()] *= -1

  return d.squeeze(0).unsqueeze(1)

# Instantiate models

In [None]:
# PVCNN is skipped: each MLP input row is the grid vertex's xyz (3) plus 3 columns of
# the voxel grid resampled with F.interpolate in test_train (6 in total).

SDF_MLP_CONFIG = {
    'input_dim' : 6, # xyz (3) + 3 resampled voxel-grid values; previously 3 (xyz only)
    'hidden_dims' : [256, 256, 128, 64],
    'output_dim' : 1, # SDF of the vertex input. The other "output" f_v comes from the prior activation layer of dimension 64
    'multires': 2 # stored by MLP, but unused while its positional encoding is disabled
}

lr = 1e-4
laplacian_weight = 0.1 # weight of the Laplacian term inside gcn_loss
save_every = 100 # NOTE(review): not referenced by the visible code
multires = 2 # duplicate of SDF_MLP_CONFIG['multires']; not referenced by the visible code
grid_res = 128 # tet-grid resolution; test_train scales GCN deformations by 1/grid_res
epoch = 1000 # NOTE(review): unused -- the training cell loops over range(3000)

In [None]:
from torchsummary import summary

sdf_model = MLP(SDF_MLP_CONFIG).to(device)
print(sdf_model)
print('\n\n')
# torchsummary allocates its dummy input on `device` (default "cuda"); pass the real
# device type so this cell also runs on a CPU-only runtime. The row width comes
# from the config instead of a second hard-coded 6.
summary(sdf_model, input_size=(tets_verts.shape[0], SDF_MLP_CONFIG['input_dim']), device=device.type)

MLP(
  (hiddens): ModuleList(
    (0): Linear(in_features=6, out_features=256, bias=True)
    (1): Linear(in_features=256, out_features=256, bias=True)
    (2): Linear(in_features=256, out_features=128, bias=True)
    (3): Linear(in_features=128, out_features=64, bias=True)
  )
  (output_layer): Linear(in_features=64, out_features=1, bias=True)
)



----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1          [-1, 277410, 256]           1,792
            Linear-2          [-1, 277410, 256]          65,792
            Linear-3          [-1, 277410, 128]          32,896
            Linear-4           [-1, 277410, 64]           8,256
            Linear-5            [-1, 277410, 1]              65
Total params: 108,801
Trainable params: 108,801
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 6.35
Forward/backward pass size (MB): 1492.11

In [None]:
CONFIG_GCNRES = {
    'in_dim': 68, # per-vertex input: xyz (3) + predicted SDF (1) + MLP feature f_v (64)
    'hidden_dim': [128, 256], # [bottleneck width, inner width of each residual block]
    'out_dim': 128, # input width of the SDF/deform/feature heads; must equal hidden_dim[0]
    'activation': True, # ReLU inside the graph blocks
    'mlp_hdim': [128,64], # hidden widths of the three linear heads
    'mlp_odim': 68, # NOTE(review): read by GCN_Res but not used
}

In [None]:
# Build the refinement GCN on the same device as the SDF MLP and show its layers.
refine_model = GCN_Res(CONFIG_GCNRES).to(device)

print(refine_model)

GCN_Res(
  (gcn_res): ModuleList(
    (0): GBottleneck(
      (blocks): Sequential(
        (0): GResBlock(
          (conv1): GraphConv(128 -> 256, directed=False)
          (conv2): GraphConv(256 -> 128, directed=False)
        )
        (1): GResBlock(
          (conv1): GraphConv(128 -> 256, directed=False)
          (conv2): GraphConv(256 -> 128, directed=False)
        )
      )
      (conv1): GraphConv(68 -> 128, directed=False)
    )
  )
  (sdf_mlp): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=False)
    (1): Linear(in_features=128, out_features=64, bias=False)
    (2): Linear(in_features=64, out_features=1, bias=False)
  )
  (deform_mlp): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=False)
    (1): Linear(in_features=128, out_features=64, bias=False)
    (2): Linear(in_features=64, out_features=3, bias=False)
  )
  (feature_mlp): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=False)
    (1): Linear(in_features

# Set up optimizers

In [None]:
def _lr_decay(step):
    """LambdaLR multiplier 10**(-0.0002 * step): the LR shrinks 10x every 5000 scheduler steps."""
    return max(0.0, 10**(-step*0.0002))

# Optimizer over the SDF MLP alone.
# NOTE(review): sdf_optimizer / sdf_scheduler are not used by the training loop,
# which passes refine_optimizer / refine_scheduler to test_train.
sdf_vars = list(sdf_model.parameters())
sdf_optimizer = torch.optim.Adam(sdf_vars, lr=lr)
sdf_scheduler = torch.optim.lr_scheduler.LambdaLR(sdf_optimizer, lr_lambda=_lr_decay)

# Joint optimizer over the SDF MLP and the refinement GCN.
params = [*sdf_model.named_parameters(), *refine_model.named_parameters()]
refine_vars = [tensor for _name, tensor in params]
refine_optimizer = torch.optim.Adam(refine_vars, lr=lr)
refine_scheduler = torch.optim.lr_scheduler.LambdaLR(refine_optimizer, lr_lambda=_lr_decay)

# Loss Function

In [None]:
from pytorch3d.loss import(
    chamfer_distance
)

def laplace_regularizer_const(mesh_verts, mesh_faces):
    """Uniform Laplacian smoothness penalty for a triangle mesh.

    For every vertex, sums ``neighbor - vertex`` over the two other corners of
    each incident triangle, divides by the number of those contributions
    (2 per incident face, clamped to at least 1 for isolated vertices), and
    returns the mean of the squared result over all vertices and coordinates.

    Args:
        mesh_verts: (V, 3) vertex positions.
        mesh_faces: (F, 3) long tensor of triangle vertex indices.

    Returns:
        Scalar tensor.
    """
    lap_sum = torch.zeros_like(mesh_verts)
    weight_sum = torch.zeros_like(mesh_verts[..., 0:1])

    corners = [mesh_verts[mesh_faces[:, c], :] for c in range(3)]
    two = torch.ones_like(corners[0]) * 2.0

    # (corner, first other corner, second other corner)
    for c, a, b in ((0, 1, 2), (1, 0, 2), (2, 0, 1)):
        contribution = (corners[a] - corners[c]) + (corners[b] - corners[c])
        lap_sum.scatter_add_(0, mesh_faces[:, c:c + 1].repeat(1, 3), contribution)
    for c in range(3):
        weight_sum.scatter_add_(0, mesh_faces[:, c:c + 1], two)

    lap_sum = lap_sum / torch.clamp(weight_sum, min=1.0)
    return torch.mean(lap_sum ** 2)

def gcn_loss(iterations, mesh_verts, mesh_faces, gt_verts, gt_faces, it,
             num_samples=100000, chamfer_weight=500, lap_weight=None):
    """Surface loss for the refinement stage.

    Samples ``num_samples`` points on the predicted and ground-truth meshes and
    takes their (unsquared) Chamfer distance, weighted by ``chamfer_weight``.
    When ``it > iterations // 2``, a uniform Laplacian smoothness term weighted by
    ``lap_weight`` (the global ``laplacian_weight`` when None) is added.

    Args:
        iterations: number of iterations in the current epoch.
        mesh_verts, mesh_faces: predicted triangle mesh, (V, 3) and (F, 3).
        gt_verts, gt_faces: ground-truth triangle mesh.
        it: current iteration index within the epoch.

    Returns:
        Scalar tensor. An empty predicted mesh yields 0 and skips the Laplacian
        term, which would otherwise be NaN (mean over zero vertices).
    """
    if lap_weight is None:
        lap_weight = laplacian_weight

    if mesh_verts.shape[0] == 0:
        return mesh_verts.new_zeros(())

    # Surface alignment loss. (The pytorch3d Meshes objects previously built here
    # were never used and have been removed.)
    pred_points = kaolin.ops.mesh.sample_points(mesh_verts.unsqueeze(0), mesh_faces, num_samples)[0][0]
    gt_points = kaolin.ops.mesh.sample_points(gt_verts.unsqueeze(0), gt_faces, num_samples)[0][0]
    chamfer = kaolin.metrics.pointcloud.chamfer_distance(
        pred_points.unsqueeze(0), gt_points.unsqueeze(0), squared=False).mean()

    loss = chamfer_weight * chamfer
    if it > iterations // 2:
        loss = loss + laplace_regularizer_const(mesh_verts, mesh_faces) * lap_weight
    return loss

# Train

In [None]:
def _normalize_verts(verts, scale=0.8):
  """Center (V, 3) ``verts`` on their bounding-box center and scale the longest box side to ``scale``."""
  center = (verts.max(0)[0] + verts.min(0)[0]) / 2
  max_l = (verts.max(0)[0] - verts.min(0)[0]).max()
  return ((verts - center) / max_l) * scale


def _prepare_gt(model_idx):
  """Load ``shapenet_train[model_idx]`` and build its normalized and watertight versions.

  Returns ``(gt_verts, gt_faces, wt_grid, wt_verts, wt_faces)``: the normalized
  input mesh (moved to ``device``), its 64^3 voxel grid, and the normalized cubic
  (watertight) mesh extracted from that grid.
  """
  gt_model = shapenet_train[model_idx]
  gt_verts = _normalize_verts(gt_model['mesh'][0]).to(device)
  gt_faces = gt_model['mesh'][1].to(device)

  wt_grid = kaolin.ops.conversions.trianglemeshes_to_voxelgrids(
      vertices=gt_verts.unsqueeze(0),
      faces=gt_faces,
      resolution=64
  )
  wt_verts, wt_faces = kaolin.ops.conversions.voxelgrids_to_cubic_meshes(wt_grid)
  wt_verts, wt_faces = _normalize_verts(wt_verts[0]), wt_faces[0]
  return gt_verts, gt_faces, wt_grid, wt_verts, wt_faces


def test_train(iterations, sdf_model, gcn_model, optimizer, scheduler, epoch,
               model_idx=25, pretrain_epochs=500):
  """Run one "epoch" of ``iterations`` optimization steps on a single ShapeNet shape.

  Epochs ``< pretrain_epochs`` pretrain: the refined SDF (MLP output plus GCN
  correction) is regressed onto the signed distance of the watertight mesh at the
  deformed grid vertices. Later epochs run marching tetrahedra on the refined
  SDF / deformed grid and minimize Chamfer (+ Laplacian) against the original
  mesh plus SDF and deformation regularizers. Every 100th refinement epoch, the
  reconstruction is written to ``timelapse``.

  Args:
    iterations: optimization steps in this call.
    sdf_model: per-vertex MLP returning ``(sdf, features)``.
    gcn_model: refinement network returning ``(sdf_delta, deform, feature_delta)``.
    optimizer, scheduler: stepped every iteration (the scheduler only after pretraining).
    epoch: epoch index; selects the phase and controls logging.
    model_idx: index into ``shapenet_train`` of the shape to fit (default 25, the
      previously hard-coded value).
    pretrain_epochs: number of SDF-only pretraining epochs (default 500).

  Uses notebook globals: shapenet_train, device, tets_verts, tets, grid_res,
  timelapse, and the helpers extract_tet, get_edges, get_gt_sdfs, gcn_loss.
  """
  gcn_model.train()
  sdf_model.train()
  avg_loss = 0

  # The target shape is fixed within a call, so load and voxelize it once instead
  # of on every iteration.
  gt_verts, gt_faces, wt_grid, wt_verts, wt_faces = _prepare_gt(model_idx)

  # MLP input: grid xyz + the voxel grid resampled to one row per grid vertex.
  # NOTE(review): F.interpolate resizes the whole 64^3 grid to a (1, N, 3) box; it
  # does not sample occupancy at each tet vertex's position, so these 3 columns are
  # not spatially tied to the vertex. F.grid_sample at tets_verts would be the usual
  # choice (it changes the MLP input width, so it is not done here).
  F_vol = F.interpolate(wt_grid.unsqueeze(0), size=(1, tets_verts.shape[0], tets_verts.shape[1]), mode='trilinear', align_corners=False)
  F_vol = torch.cat((tets_verts, F_vol.squeeze(0).squeeze(0).squeeze(0).to(device)), dim=1)

  for i in range(iterations):
    pred_sdfs, f_vs = sdf_model(F_vol)

    # --- Surface refinement ---
    # Vertices of tets whose four SDFs all lie in [-0.008, 0.008]; adjust this
    # threshold if no tets get selected.
    surf_tets_verts_idx, surf_tets_faces = extract_tet(tets, pred_sdfs, 0.008)
    surf_tets_verts = torch.clone(tets_verts[surf_tets_verts_idx])
    surf_tets_verts_features = torch.clone(f_vs[surf_tets_verts_idx])
    surf_sdfs = pred_sdfs[surf_tets_verts_idx]
    surf_tets_edges = torch.clone(get_edges(surf_tets_faces).to(device))
    surf_verts_f = torch.cat((surf_tets_verts, surf_sdfs, surf_tets_verts_features), dim=1)

    # The GCN's feature update (third output) is not used by any loss below.
    sdf, deform, _ = gcn_model(surf_verts_f, surf_tets_edges)

    # Apply the GCN's residual SDF and position updates to the selected vertices only.
    update_sdfs = pred_sdfs.clone()
    update_sdfs[surf_tets_verts_idx] += sdf
    update_tets_verts = tets_verts.clone()
    update_tets_verts[surf_tets_verts_idx] += deform / grid_res  # tanh output: at most 1/grid_res per axis

    if epoch < pretrain_epochs:
      gt_sdfs = get_gt_sdfs(wt_verts.unsqueeze(0), wt_faces, update_tets_verts.unsqueeze(0))
      sdf_loss = F.mse_loss(update_sdfs, gt_sdfs, reduction='mean')

      optimizer.zero_grad()
      sdf_loss.backward()  # the graph is rebuilt every iteration; no need to retain it
      optimizer.step()
      avg_loss += sdf_loss.item()

      if epoch == 0 and i == 0:
        print('========== Start pretraining ==========')

      if i == (iterations - 1) and epoch % 100 == 0:
        print ('Epoch {} - loss: {}'.format(epoch, avg_loss/iterations))

      continue

    # --- Marching tetrahedra on the refined SDF and deformed grid ---
    mesh_verts, mesh_faces = kaolin.ops.conversions.marching_tetrahedra(update_tets_verts.unsqueeze(0), tets, update_sdfs.squeeze(1).unsqueeze(0))
    mesh_verts, mesh_faces = mesh_verts[0], mesh_faces[0]

    # --- Refinement loss: surface alignment (+ Laplacian) + SDF L2 + deformation L2 ---
    # SDF regression only near the surface (|target| <= 0.3).
    s_sdfs = get_gt_sdfs(gt_verts.unsqueeze(0), gt_faces, update_tets_verts.unsqueeze(0))
    mask = ((s_sdfs >= -0.3) & (s_sdfs <= 0.3)).squeeze(1)
    sdf_loss = F.mse_loss(update_sdfs[mask], s_sdfs[mask], reduction='mean')

    deform_loss = F.mse_loss(update_tets_verts, tets_verts, reduction='mean')

    r_loss = gcn_loss(iterations, mesh_verts, mesh_faces, gt_verts, gt_faces, i)

    g_loss = r_loss + deform_loss + 0.4*sdf_loss

    optimizer.zero_grad()
    # torch.autograd.set_detect_anomaly(True) used to be called here on every step;
    # it is a global debugging switch that makes every backward pass much slower.
    g_loss.backward()
    avg_loss += g_loss.item()
    optimizer.step()
    scheduler.step()

    if epoch == pretrain_epochs and i == 0:
      print('========== Start Refinement ==========')

    if epoch % 100 == 0:
      # save reconstructed mesh
      timelapse.add_mesh_batch(
          iteration=epoch+1,
          category='final_train_res',
          vertices_list=[mesh_verts.cpu()],
          faces_list=[mesh_faces.cpu()]
      )

      if i == (iterations - 1):
        print ('Epoch {} - loss: {}, # of mesh vertices: {}, # of mesh faces: {}'.format(epoch, avg_loss/iterations, mesh_verts.shape[0], mesh_faces.shape[0]))

In [None]:
# Ground truth for the viewer: the same item test_train fits (shapenet_train[25]),
# normalized the same way (bbox-centered, longest side 0.8) so it overlays the
# reconstructions. Previously this showed item 20, unnormalized.
v_model = shapenet_train[25] #keep this index in sync with the training example
v_verts = v_model['mesh'][0]
v_center = (v_verts.max(0)[0] + v_verts.min(0)[0]) / 2
v_max_l = (v_verts.max(0)[0] - v_verts.min(0)[0]).max()
v_verts = (((v_verts - v_center) / v_max_l) * 0.8).to(device)
v_faces = v_model['mesh'][1]

timelapse.add_mesh_batch(
      category='gt',
      vertices_list=[v_verts.cpu()],
      faces_list=[v_faces.cpu()]
)


tensor([[0.5700],
        [0.5689],
        [0.5683],
        ...,
        [0.5367],
        [0.5346],
        [0.5344]], device='cuda:0')
torch.Size([6301, 3])
torch.Size([20778, 3])


In [None]:
# Each test_train call is one "epoch" of 1 step on the single training shape; the
# loop index is passed as the epoch, which switches from SDF pretraining to
# surface refinement at epoch 500.
# NOTE(review): this runs 3000 epochs, while the config cell sets `epoch = 1000` (unused).
for i in range(3000):
  test_train(1, sdf_model, refine_model, refine_optimizer, refine_scheduler, i)

Epoch 0 - loss: 0.007392390631139278
Epoch 100 - loss: 6.907405853271484
Epoch 200 - loss: 0.11011458933353424
Epoch 300 - loss: 0.04961121454834938
Epoch 400 - loss: 0.03220337629318237
Epoch 500 - loss: 96.9729232788086, # of mesh vertices: 257625, # of mesh faces: 511583
Epoch 600 - loss: 58.180747985839844, # of mesh vertices: 34787, # of mesh faces: 68934
Epoch 700 - loss: 31.18280029296875, # of mesh vertices: 40110, # of mesh faces: 79868
Epoch 800 - loss: 34.707794189453125, # of mesh vertices: 29231, # of mesh faces: 58152
Epoch 900 - loss: 26.965675354003906, # of mesh vertices: 21904, # of mesh faces: 43760
Epoch 1000 - loss: 23.684951782226562, # of mesh vertices: 42318, # of mesh faces: 85150
Epoch 1100 - loss: 7.034949779510498, # of mesh vertices: 22558, # of mesh faces: 45124
Epoch 1200 - loss: 7.041993141174316, # of mesh vertices: 34202, # of mesh faces: 68928
Epoch 1300 - loss: 4.535536289215088, # of mesh vertices: 35340, # of mesh faces: 71792
Epoch 1400 - loss: 6.

In [None]:
# Release the Colab runtime after training (frees the GPU).
# NOTE(review): this presumably disconnects the session, so the Visualization cells
# below cannot run afterwards in the same session; skip this cell when you want to
# inspect results with Dash3D.
from google.colab import runtime
runtime.unassign()

# Visualization

In [None]:
#Use pyngrok to access localhost:80 on Colab

!pip install pyngrok --quiet 
from pyngrok import ngrok

# Terminate open tunnels if exist
ngrok.kill()

# Setting the authtoken (optional)
# Get authtoken from https://dashboard.ngrok.com/auth
# Never hard-code the token in the notebook: read it from the NGROK_AUTH_TOKEN
# environment variable, or prompt for it without echoing. Leave it empty to skip.
# (The token that was previously committed in this cell should be revoked.)
import os
from getpass import getpass
NGROK_AUTH_TOKEN = os.environ.get("NGROK_AUTH_TOKEN") or getpass("ngrok authtoken (leave empty to skip): ")
if NGROK_AUTH_TOKEN:
    ngrok.set_auth_token(NGROK_AUTH_TOKEN)

[?25l[K     |▍                               | 10 kB 35.4 MB/s eta 0:00:01[K     |▉                               | 20 kB 9.5 MB/s eta 0:00:01[K     |█▎                              | 30 kB 13.4 MB/s eta 0:00:01[K     |█▊                              | 40 kB 6.5 MB/s eta 0:00:01[K     |██▏                             | 51 kB 6.1 MB/s eta 0:00:01[K     |██▋                             | 61 kB 7.2 MB/s eta 0:00:01[K     |███                             | 71 kB 7.7 MB/s eta 0:00:01[K     |███▍                            | 81 kB 6.1 MB/s eta 0:00:01[K     |███▉                            | 92 kB 6.8 MB/s eta 0:00:01[K     |████▎                           | 102 kB 6.6 MB/s eta 0:00:01[K     |████▊                           | 112 kB 6.6 MB/s eta 0:00:01[K     |█████▏                          | 122 kB 6.6 MB/s eta 0:00:01[K     |█████▋                          | 133 kB 6.6 MB/s eta 0:00:01[K     |██████                          | 143 kB 6.6 MB/s eta 0:00:01[K   

In [None]:
# Generate a public HTTPS URL tunnelled to localhost:80, where Dash3D is served below.
# pyngrok 5+ takes the local address as `addr` and tunnel options as keyword
# arguments; the legacy `port=` / `options={...}` form is not recognized (presumably
# why the tunnel previously came up as plain http:// despite bind_tls).
public_url = ngrok.connect(addr=80, proto="http", bind_tls=True)
print("Tracking URL:", public_url)

Tracking URL: NgrokTunnel: "http://06f4-35-247-79-87.ngrok.io" -> "http://localhost:80"


In [None]:
#Start Kaolin Dash3D on localhost:80 
!kaolin-dash3d --logdir=/content/drive/MyDrive/CV_DMTet/Logs --port=80

Dash3D server starting. Go to: http://localhost:80
2022-12-15 05:04:46,497|    INFO|kaolin.visualize.timelapse| No checkpoints found for type pointcloud: no files matched pattern pointcloud*.usd in /content/drive/MyDrive/CV_DMTet/Logs
2022-12-15 05:04:46,508|    INFO|kaolin.visualize.timelapse| No checkpoints found for type voxelgrid: no files matched pattern voxelgrid*.usd in /content/drive/MyDrive/CV_DMTet/Logs
2022-12-15 05:05:02,553|    INFO|kaolin.visualize.timelapse| No checkpoints found for type pointcloud: no files matched pattern pointcloud*.usd in /content/drive/MyDrive/CV_DMTet/Logs
2022-12-15 05:05:02,566|    INFO|kaolin.visualize.timelapse| No checkpoints found for type voxelgrid: no files matched pattern voxelgrid*.usd in /content/drive/MyDrive/CV_DMTet/Logs
2022-12-15 05:05:02,589|    INFO| tornado.access| 200 GET / (127.0.0.1) 65.04ms
2022-12-15 05:05:02,751|    INFO| tornado.access| 200 GET /static/thirdparty.css (127.0.0.1) 4.01ms
2022-12-15 05:05:02,783|    INFO| tor