In [3]:
import os

os.chdir("/root/dev/playground/knowledges/3d_representation/DMTet/")
os.getcwd()

'/root/dev/playground/knowledges/3d_representation/DMTet'

In [4]:
!gpustat

[1m[37m6f37a742720d              [m  Wed Apr 17 01:20:03 2024  [1m[30m525.89.02[m
[36m[0][m [34mNVIDIA GeForce RTX 4090[m |[31m 41°C[m, [32m  7 %[m | [36m[1m[33m  945[m / [33m24564[m MB |
[36m[1][m [34mNVIDIA GeForce RTX 4090[m |[31m 37°C[m, [32m  0 %[m | [36m[1m[33m  399[m / [33m24564[m MB |


In [6]:
import torch

torch.cuda.set_device(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Explore DMTet

code from [nvdiffrec](https://github.com/NVlabs/nvdiffrec/blob/main/geometry/dmtet.py) - Extracting Triangular 3D Models, Materials, and Lighting From Images

In [263]:
import numpy as np
import torch

## Marching Tetrahedra Algorithm

In [1]:
###############################################################################
# Marching tetrahedrons implementation (differentiable), adapted from
# https://github.com/NVIDIAGameWorks/kaolin/blob/master/kaolin/ops/conversions/tetmesh.py
###############################################################################

class DMTet:
    def __init__(self):
        self.triangle_table = torch.tensor([
                [-1, -1, -1, -1, -1, -1],
                [ 1,  0,  2, -1, -1, -1],
                [ 4,  0,  3, -1, -1, -1],
                [ 1,  4,  2,  1,  3,  4],
                [ 3,  1,  5, -1, -1, -1],
                [ 2,  3,  0,  2,  5,  3],
                [ 1,  4,  0,  1,  5,  4],
                [ 4,  2,  5, -1, -1, -1],
                [ 4,  5,  2, -1, -1, -1],
                [ 4,  1,  0,  4,  5,  1],
                [ 3,  2,  0,  3,  5,  2],
                [ 1,  3,  5, -1, -1, -1],
                [ 4,  1,  2,  4,  3,  1],
                [ 3,  0,  4, -1, -1, -1],
                [ 2,  0,  1, -1, -1, -1],
                [-1, -1, -1, -1, -1, -1]
                ], dtype=torch.long, device='cuda')

        self.num_triangles_table = torch.tensor([0,1,1,2,1,2,2,1,1,2,2,1,2,1,1,0], dtype=torch.long, device='cuda')
        self.base_tet_edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype=torch.long, device='cuda')

    ###############################################################################
    # Utility functions
    ###############################################################################

    def sort_edges(self, edges_ex2):
        with torch.no_grad():
            order = (edges_ex2[:,0] > edges_ex2[:,1]).long()
            order = order.unsqueeze(dim=1)

            a = torch.gather(input=edges_ex2, index=order, dim=1)      
            b = torch.gather(input=edges_ex2, index=1-order, dim=1)  

        return torch.stack([a, b],-1)

    def map_uv(self, faces, face_gidx, max_idx):
        N = int(np.ceil(np.sqrt((max_idx+1)//2)))
        tex_y, tex_x = torch.meshgrid(
            torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda"),
            torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda"),
            indexing='ij'
        )

        pad = 0.9 / N

        uvs = torch.stack([
            tex_x      , tex_y,
            tex_x + pad, tex_y,
            tex_x + pad, tex_y + pad,
            tex_x      , tex_y + pad
        ], dim=-1).view(-1, 2)

        def _idx(tet_idx, N):
            x = tet_idx % N
            y = torch.div(tet_idx, N, rounding_mode='trunc')
            return y * N + x

        tet_idx = _idx(torch.div(face_gidx, 2, rounding_mode='trunc'), N)
        tri_idx = face_gidx % 2

        uv_idx = torch.stack((
            tet_idx * 4, tet_idx * 4 + tri_idx + 1, tet_idx * 4 + tri_idx + 2
        ), dim = -1). view(-1, 3)

        return uvs, uv_idx

    ###############################################################################
    # Marching tets implementation
    ###############################################################################

    def __call__(self, pos_nx3, sdf_n, tet_fx4):
        with torch.no_grad():
            occ_n = sdf_n > 0
            occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1,4)
            occ_sum = torch.sum(occ_fx4, -1)
            valid_tets = (occ_sum>0) & (occ_sum<4)
            occ_sum = occ_sum[valid_tets]

            # find all vertices
            all_edges = tet_fx4[valid_tets][:,self.base_tet_edges].reshape(-1,2)
            all_edges = self.sort_edges(all_edges)
            unique_edges, idx_map = torch.unique(all_edges,dim=0, return_inverse=True)  
            
            unique_edges = unique_edges.long()
            mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1,2).sum(-1) == 1
            mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device="cuda") * -1
            mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long,device="cuda")
            idx_map = mapping[idx_map] # map edges to verts

            interp_v = unique_edges[mask_edges]
        edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1,2,3)
        edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1,2,1)
        edges_to_interp_sdf[:,-1] *= -1

        denominator = edges_to_interp_sdf.sum(1,keepdim = True)

        edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1])/denominator
        verts = (edges_to_interp * edges_to_interp_sdf).sum(1)

        idx_map = idx_map.reshape(-1,6)

        v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device="cuda"))
        tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
        num_triangles = self.num_triangles_table[tetindex]

        # Generate triangle indices
        faces = torch.cat((
            torch.gather(input=idx_map[num_triangles == 1], dim=1, index=self.triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1,3),
            torch.gather(input=idx_map[num_triangles == 2], dim=1, index=self.triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1,3),
        ), dim=0)

        # Get global face index (static, does not depend on topology)
        num_tets = tet_fx4.shape[0]
        tet_gidx = torch.arange(num_tets, dtype=torch.long, device="cuda")[valid_tets]
        face_gidx = torch.cat((
            tet_gidx[num_triangles == 1]*2,
            torch.stack((tet_gidx[num_triangles == 2]*2, tet_gidx[num_triangles == 2]*2 + 1), dim=-1).view(-1)
        ), dim=0)

        uvs, uv_idx = self.map_uv(faces, face_gidx, num_tets*2)

        return verts, faces, uvs, uv_idx

### Init

In [46]:
triangle_table = torch.tensor([
                [-1, -1, -1, -1, -1, -1],
                [ 1,  0,  2, -1, -1, -1],
                [ 4,  0,  3, -1, -1, -1],
                [ 1,  4,  2,  1,  3,  4],
                [ 3,  1,  5, -1, -1, -1],
                [ 2,  3,  0,  2,  5,  3],
                [ 1,  4,  0,  1,  5,  4],
                [ 4,  2,  5, -1, -1, -1],
                [ 4,  5,  2, -1, -1, -1],
                [ 4,  1,  0,  4,  5,  1],
                [ 3,  2,  0,  3,  5,  2],
                [ 1,  3,  5, -1, -1, -1],
                [ 4,  1,  2,  4,  3,  1],
                [ 3,  0,  4, -1, -1, -1],
                [ 2,  0,  1, -1, -1, -1],
                [-1, -1, -1, -1, -1, -1]
                ], dtype=torch.long) # total number of surface typologies: 2^4 = 16
num_triangles_table = torch.tensor([0,1,1,2,1,2,2,1,1,2,2,1,2,1,1,0], dtype=torch.long) # unique cases: 3
base_tet_edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype=torch.long) # number of edges of a tetrohydron: 6

In [47]:
triangle_table.shape

torch.Size([16, 6])

In [48]:
num_triangles_table.shape

torch.Size([16])

In [49]:
base_tet_edges.shape

torch.Size([12])

### call

In [166]:
sdf = torch.rand((2048,)) * 2 - 1
pos = torch.rand((2048, 3)) * 2 - 1
tet_feat = torch.randint(low=0, high=10, size=(256, 4)) # 256?

In [167]:
occ = sdf > 0

In [168]:
occ_feat = occ[tet_feat.reshape(-1)].reshape(-1, 4) # (256, 4)

In [169]:
occ_sum = torch.sum(occ_feat, -1) # (256)

In [170]:
valid_tets = (occ_sum > 0) & (occ_sum < 4) # T_\text{surf}

In [171]:
occ_sum = occ_sum[valid_tets]

In [172]:
all_edges = tet_feat[valid_tets][:, base_tet_edges].reshape(-1, 2)

In [173]:
tet_feat[valid_tets].shape

torch.Size([215, 4])

In [174]:
tet_feat[valid_tets][:, base_tet_edges].shape

torch.Size([215, 12])

In [175]:
tet_feat[valid_tets][0]

tensor([1, 2, 6, 5])

In [176]:
base_tet_edges

tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3])

In [177]:
tet_feat[valid_tets][0, base_tet_edges]

tensor([1, 2, 1, 6, 1, 5, 2, 6, 2, 5, 6, 5])

In [178]:
def sort_edges(edges):
    order = (edges[:, 0] > edges[:, 1]).long() # (e) 0 or 1
    order = order.unsqueeze(dim=1) # (e, 1)

    a = torch.gather(input=edges, index=order, dim=1)
    b = torch.gather(input=edges, index=1 - order, dim=1)
    return torch.stack([a, b], -1) # (total_num_edges, 1, 2)

In [179]:
all_edges.shape

torch.Size([1290, 2])

In [180]:
all_edges[:10]

tensor([[1, 2],
        [1, 6],
        [1, 5],
        [2, 6],
        [2, 5],
        [6, 5],
        [8, 4],
        [8, 4],
        [8, 6],
        [4, 4]])

In [181]:
order = (all_edges[:10, 0] > all_edges[:10, 1]).long().unsqueeze(dim=1)

In [182]:
order

tensor([[0],
        [0],
        [0],
        [0],
        [0],
        [1],
        [1],
        [1],
        [1],
        [0]])

In [183]:
torch.gather(input=all_edges[:10], index=order, dim=1)

tensor([[1],
        [1],
        [1],
        [2],
        [2],
        [5],
        [4],
        [4],
        [6],
        [4]])

In [184]:
all_edges = sort_edges(all_edges)

In [185]:
all_edges.shape

torch.Size([1290, 1, 2])

In [186]:
all_edges[0]

tensor([[1, 2]])

In [187]:
unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)

In [188]:
unique_edges = unique_edges.long()

In [189]:
unique_edges.shape # (num_edges, 1, 2)

torch.Size([55, 1, 2])

In [190]:
idx_map.shape # (total_num_edges) original input idx

torch.Size([1290])

In [191]:
occ.shape

torch.Size([2048])

In [192]:
mask_edges = occ[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 # sign of SDF changed

In [193]:
mask_edges.shape

torch.Size([55])

In [194]:
mask_edges

tensor([False, False, False, False,  True,  True,  True,  True, False, False,
        False, False, False,  True,  True,  True,  True, False, False, False,
        False,  True,  True,  True,  True, False, False, False,  True,  True,
         True,  True, False, False, False, False, False, False,  True,  True,
        False, False, False,  True,  True, False, False,  True,  True, False,
         True,  True, False, False, False])

In [195]:
mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long) * -1 # (num_edges)

In [196]:
mapping.shape

torch.Size([55])

In [197]:
mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long)

In [198]:
torch.arange(mask_edges.sum(), dtype=torch.long)

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23])

In [199]:
for i in range(10):
    print(mapping[i])

tensor(-1)
tensor(-1)
tensor(-1)
tensor(-1)
tensor(0)
tensor(1)
tensor(2)
tensor(3)
tensor(-1)
tensor(-1)


In [200]:
idx_map

tensor([11, 15, 14,  ..., 24, 31, 20])

In [201]:
idx_map = mapping[idx_map] # map edges to verts

In [202]:
idx_map

tensor([-1,  6,  5,  ..., 11, 15, -1])

In [203]:
idx_map.shape

torch.Size([1290])

In [204]:
interp_v = unique_edges[mask_edges]

In [205]:
interp_v.shape

torch.Size([24, 1, 2])

In [206]:
edges_to_interp = pos[interp_v.reshape(-1)].reshape(-1, 2, 3) # (num_v, 2, 3)
edges_to_interp_sdf = sdf[interp_v.reshape(-1)].reshape(-1, 2, 1) # (num_v, 2, 1)
edges_to_interp_sdf[:, -1] *= -1 # change the second vertices of the edge (for interpolation)

In [207]:
edges_to_interp.shape

torch.Size([24, 2, 3])

In [208]:
edges_to_interp_sdf.shape

torch.Size([24, 2, 1])

In [209]:
denominator = edges_to_interp_sdf.sum(1, keepdim=True) # s_b + (-s_a)

In [210]:
edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator

In [211]:
verts = (edges_to_interp * edges_to_interp_sdf).sum(1) # (v_a s_b - v_b s_a) / (s_b - s_a)

In [212]:
idx_map.shape

torch.Size([1290])

In [213]:
idx_map = idx_map.reshape(-1, 6) # (total_num_edges) → (total_num_edges / 6, 6)

In [214]:
idx_map.shape

torch.Size([215, 6])

In [215]:
v_id = torch.pow(2, torch.arange(4, dtype=torch.long))

In [216]:
occ_feat[valid_tets].shape

torch.Size([215, 4])

In [219]:
occ_feat[valid_tets][:5]

tensor([[False, False,  True,  True],
        [False,  True,  True,  True],
        [ True, False, False,  True],
        [False, False,  True,  True],
        [ True, False,  True,  True]])

In [218]:
(occ_feat[valid_tets] * v_id.unsqueeze(0))[:5]

tensor([[0, 0, 4, 8],
        [0, 2, 4, 8],
        [1, 0, 0, 8],
        [0, 0, 4, 8],
        [1, 0, 4, 8]])

In [220]:
tet_idx = (occ_feat[valid_tets] * v_id.unsqueeze(0)).sum(-1)

In [222]:
tet_idx[:5]

tensor([12, 14,  9, 12, 13])

In [229]:
triangle_table.shape

torch.Size([16, 6])

In [231]:
num_triangles = num_triangles_table[tet_idx]

In [235]:
num_triangles.shape # (num_triangle)

torch.Size([215])

In [236]:
# Generate triangle indices
faces = torch.cat((
    torch.gather(input=idx_map[num_triangles == 1], dim=1, index=triangle_table[tet_idx[num_triangles == 1]][:, :3]).reshape(-1,3),
    torch.gather(input=idx_map[num_triangles == 2], dim=1, index=triangle_table[tet_idx[num_triangles == 2]][:, :6]).reshape(-1,3),
), dim=0)

- num_triangles == 1: unique typology is 1 (0 means no mesh)
- tet_idx\[unique_typology\]: which triangle mesh out of the total number of possible typologies (16)

In [237]:
faces.shape

torch.Size([307, 3])

In [238]:
faces[:5]

tensor([[20, 16, 16],
        [14, 13, 13],
        [ 6,  6, 21],
        [ 1, 13,  1],
        [ 4,  4,  6]])

In [239]:
# Get global face index (static, does not depend on topology)
num_tets = tet_feat.shape[0]
tet_gidx = torch.arange(num_tets, dtype=torch.long)[valid_tets]

In [241]:
tet_gidx

tensor([  0,   1,   2,   3,   4,   6,   7,   8,   9,  10,  11,  12,  13,  14,
         15,  18,  19,  20,  21,  22,  23,  25,  26,  28,  29,  30,  31,  32,
         34,  35,  36,  37,  38,  39,  40,  41,  42,  43,  44,  45,  47,  48,
         50,  51,  52,  53,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
         65,  66,  67,  68,  71,  72,  75,  77,  79,  80,  81,  82,  83,  84,
         85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  97,  99, 100,
        101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114,
        115, 116, 117, 118, 120, 122, 124, 125, 126, 127, 128, 129, 130, 131,
        132, 136, 137, 138, 142, 143, 145, 146, 147, 148, 149, 150, 151, 152,
        153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166,
        167, 168, 169, 170, 171, 173, 174, 175, 176, 177, 179, 180, 181, 183,
        184, 185, 187, 188, 189, 190, 191, 193, 194, 195, 196, 197, 198, 199,
        200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 212, 2

In [245]:
valid_tets[5]

tensor(False)

In [249]:
tet_gidx[num_triangles == 1].shape

torch.Size([123])

In [250]:
tet_gidx[num_triangles == 2].shape

torch.Size([92])

In [254]:
(tet_gidx[num_triangles == 2] * 2)[:5]

tensor([ 0,  4,  6, 12, 16])

In [256]:
torch.stack((tet_gidx[num_triangles == 2] * 2, tet_gidx[num_triangles == 2] * 2 + 1), dim=-1).shape

torch.Size([92, 2])

In [246]:
face_gidx = torch.cat((
    tet_gidx[num_triangles == 1] * 2,
    torch.stack((tet_gidx[num_triangles == 2] * 2, tet_gidx[num_triangles == 2] * 2 + 1), dim=-1).view(-1)
), dim=0)

In [258]:
face_gidx.shape # 123 + 184 = 307

torch.Size([307])

In [261]:
def map_uv(faces, face_gidx, max_idx):
    N = int(np.ceil(np.sqrt((max_idx+1)//2)))
    tex_y, tex_x = torch.meshgrid(
        torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32),
        torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32),
        indexing='ij'
    )

    pad = 0.9 / N

    uvs = torch.stack([
        tex_x      , tex_y,
        tex_x + pad, tex_y,
        tex_x + pad, tex_y + pad,
        tex_x      , tex_y + pad
    ], dim=-1).view(-1, 2)

    def _idx(tet_idx, N):
        x = tet_idx % N
        y = torch.div(tet_idx, N, rounding_mode='trunc')
        return y * N + x

    tet_idx = _idx(torch.div(face_gidx, 2, rounding_mode='trunc'), N)
    tri_idx = face_gidx % 2

    uv_idx = torch.stack((
        tet_idx * 4, tet_idx * 4 + tri_idx + 1, tet_idx * 4 + tri_idx + 2
    ), dim = -1). view(-1, 3)

    return uvs, uv_idx

In [281]:
y, x = torch.meshgrid(torch.linspace(0, 1, 6, dtype=torch.float32), torch.linspace(0, 1, 6, dtype=torch.float32), indexing="ij")

In [282]:
y

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
        [0.4000, 0.4000, 0.4000, 0.4000, 0.4000, 0.4000],
        [0.6000, 0.6000, 0.6000, 0.6000, 0.6000, 0.6000],
        [0.8000, 0.8000, 0.8000, 0.8000, 0.8000, 0.8000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]])

In [283]:
x

tensor([[0.0000, 0.2000, 0.4000, 0.6000, 0.8000, 1.0000],
        [0.0000, 0.2000, 0.4000, 0.6000, 0.8000, 1.0000],
        [0.0000, 0.2000, 0.4000, 0.6000, 0.8000, 1.0000],
        [0.0000, 0.2000, 0.4000, 0.6000, 0.8000, 1.0000],
        [0.0000, 0.2000, 0.4000, 0.6000, 0.8000, 1.0000],
        [0.0000, 0.2000, 0.4000, 0.6000, 0.8000, 1.0000]])

In [264]:
uvs, uv_idx = map_uv(faces, face_gidx, num_tets*2)

In [265]:
uvs.shape

torch.Size([1024, 2])

In [266]:
uv_idx.shape

torch.Size([307, 3])

## DMTet Geometry Interface

In [None]:
from torch import nn

class DMTetGeometry(nn.Module):
    def __init__(self, grid_res, scale, FLAGS):
        super().__init__()

        self.FLAGS = FLAGS
        self.grid_res = grid_res
        self.marching_tets = DMTet() # verts, faces, uvs, uv_idx

        # tets = np.load("data/tets/{}_tets.npz".format(self.grid_res))
        self.verts = torch.tensor(tets["vertices"], dtype=torch.float32) * scale
        self.indices = torch.tensor(tets["indices"], dtype=torch.long)
        
        # generate edges
        with torch.no_grad():
            edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype = torch.long, device = "cuda")
            all_edges = self.indices[:,edges].reshape(-1,2)
            all_edges_sorted = torch.sort(all_edges, dim=1)[0]
            self.all_edges = torch.unique(all_edges_sorted, dim=0)

        # random initialization
        sdf = torch.rand_like(self.verts[:, 0]) - 0.1

        self.sdf = torch.nn.Parameter(sdf.clone().detach(), requires_grad=True)
        self.register_parameter("sdf", self.sdf)

        self.deform = torch.nn.Parameter(torch.zeros_like(self.verts), requires_grad=True)
        self.register_parameter("deform", self.deform)

    @torch.no_grad()
    def getAABB(self):
        return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values

    def getMesh(self, material):
        # run DMTet to get a base mesh
        v_deformed = self.verts + 2 / (self.grid_res * 2) * torch.tanh(self.deform) # tanh: real value to [-1, 1]
        verts, faces, uvs, uv_idx = self.marching_tets(v_deformed, self.sdf, self.indices) # pos, sdf, tet
        imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material)

        # run mesh operations to generate tangent space
        imesh = mesh.auto_normals(imesh)
        imesh = mesh.compute_tangents(imesh)
        return imesh

    def render(self, glctx, target, lgt, opt_material, bsdf=None):
        opt_mesh = self.getMesh(opt_material)
        return render.render_mesh(glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'], 
                                        msaa=True, background=target['background'], bsdf=bsdf)


    def tick(self, glctx, target, lgt, opt_material, loss_fn, iteration):

        # ==============================================================================================
        #  Render optimizable object with identical conditions
        # ==============================================================================================
        buffers = self.render(glctx, target, lgt, opt_material)

        # ==============================================================================================
        #  Compute loss
        # ==============================================================================================
        t_iter = iteration / self.FLAGS.iter

        # Image-space loss, split into a coverage component and a color component
        color_ref = target['img']
        img_loss = torch.nn.functional.mse_loss(buffers['shaded'][..., 3:], color_ref[..., 3:]) 
        img_loss = img_loss + loss_fn(buffers['shaded'][..., 0:3] * color_ref[..., 3:], color_ref[..., 0:3] * color_ref[..., 3:])

        # SDF regularizer
        sdf_weight = self.FLAGS.sdf_regularizer - (self.FLAGS.sdf_regularizer - 0.01)*min(1.0, 4.0 * t_iter)
        reg_loss = sdf_reg_loss(self.sdf, self.all_edges).mean() * sdf_weight # Dropoff to 0.01

        # Albedo (k_d) smoothnesss regularizer
        reg_loss += torch.mean(buffers['kd_grad'][..., :-1] * buffers['kd_grad'][..., -1:]) * 0.03 * min(1.0, iteration / 500)

        # Visibility regularizer
        reg_loss += torch.mean(buffers['occlusion'][..., :-1] * buffers['occlusion'][..., -1:]) * 0.001 * min(1.0, iteration / 500)

        # Light white balance regularizer
        reg_loss = reg_loss + lgt.regularizer() * 0.005

        return img_loss, reg_loss

# Use DMTet

- Tutorial from [link](https://github.com/NVIDIAGameWorks/kaolin/tree/master)