Commit

add mt tutorial (#492)
Signed-off-by: Frank Shen <frshen@nvidia.com>

Co-authored-by: Frank Shen <frshen@nvidia.com>
frankshen07 committed Dec 15, 2021
1 parent 4ac4b10 commit 2428ae8
Showing 10 changed files with 444 additions and 0 deletions.
Binary file added examples/samples/128_tets_0.npz
Binary file added examples/samples/128_tets_1.npz
Binary file added examples/samples/128_tets_2.npz
Binary file added examples/samples/128_tets_3.npz
Binary file added examples/samples/128_verts.npz
Binary file added examples/samples/bear_pointcloud.usd
Binary file added examples/samples/dash3d_mesh.png
Binary file added examples/samples/dash3d_pcd.png
356 changes: 356 additions & 0 deletions examples/tutorial/dmtet_tutorial.ipynb

Large diffs are not rendered by default.

88 changes: 88 additions & 0 deletions examples/tutorial/network.py
@@ -0,0 +1,88 @@
import torch
from tqdm import tqdm

# MLP + Positional Encoding
class Decoder(torch.nn.Module):
    def __init__(self, input_dims=3, internal_dims=128, output_dims=4, hidden=5, multires=2):
        super().__init__()
        self.embed_fn = None
        if multires > 0:
            # Replace the raw 3-D input with its positional encoding.
            embed_fn, input_ch = get_embedder(multires)
            self.embed_fn = embed_fn
            input_dims = input_ch

        # MLP: `hidden` Linear+ReLU blocks followed by a linear output layer.
        net = (torch.nn.Linear(input_dims, internal_dims, bias=False), torch.nn.ReLU())
        for i in range(hidden - 1):
            net = net + (torch.nn.Linear(internal_dims, internal_dims, bias=False), torch.nn.ReLU())
        net = net + (torch.nn.Linear(internal_dims, output_dims, bias=False),)
        self.net = torch.nn.Sequential(*net)

    def forward(self, p):
        if self.embed_fn is not None:
            p = self.embed_fn(p)
        out = self.net(p)
        return out

    def pre_train_sphere(self, iter):
        """Fit the first output channel (the SDF) to a sphere of radius 0.3."""
        print("Initialize SDF to sphere")
        loss_fn = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(list(self.parameters()), lr=1e-4)

        for i in tqdm(range(iter)):
            # Random points in the cube [-0.5, 0.5]^3 and their analytic sphere SDF.
            p = torch.rand((1024, 3), device='cuda') - 0.5
            ref_value = torch.sqrt((p**2).sum(-1)) - 0.3
            output = self(p)
            loss = loss_fn(output[..., 0], ref_value)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print("Pre-trained MLP", loss.item())


# Positional Encoding from https://github.com/yenchenlin/nerf-pytorch/blob/1f064835d2cca26e4df2d7d130daa39a8cee1795/run_nerf_helpers.py
class Embedder:
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        embed_fns = []
        d = self.kwargs['input_dims']
        out_dim = 0
        if self.kwargs['include_input']:
            # Pass the raw input through alongside its encoding.
            embed_fns.append(lambda x: x)
            out_dim += d

        max_freq = self.kwargs['max_freq_log2']
        N_freqs = self.kwargs['num_freqs']

        if self.kwargs['log_sampling']:
            freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs)
        else:
            freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, steps=N_freqs)

        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:
                # Bind p_fn and freq as default arguments so each lambda keeps its own values.
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
                out_dim += d

        self.embed_fns = embed_fns
        self.out_dim = out_dim

    def embed(self, inputs):
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)


def get_embedder(multires):
    embed_kwargs = {
        'include_input': True,
        'input_dims': 3,
        'max_freq_log2': multires - 1,
        'num_freqs': multires,
        'log_sampling': True,
        'periodic_fns': [torch.sin, torch.cos],
    }

    embedder_obj = Embedder(**embed_kwargs)
    embed = lambda x, eo=embedder_obj: eo.embed(x)
    return embed, embedder_obj.out_dim
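
For reference, get_embedder(multires) produces an encoding of width input_dims * (1 + 2 * num_freqs) when include_input is set: with the defaults above (multires=2, 3-D input) that is 3 * (1 + 2 * 2) = 15, which is the input width the Decoder's first Linear layer receives. A quick check of that arithmetic (a hypothetical snippet, not part of this commit; it assumes the file is importable as network.py):

import torch
from network import get_embedder  # assumed import path

embed_fn, out_dim = get_embedder(2)
print(out_dim)                      # 15 = 3 * (1 + 2 * 2)

p = torch.rand(4, 3) - 0.5          # a few points in [-0.5, 0.5]^3
print(embed_fn(p).shape)            # torch.Size([4, 15])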
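
And a minimal usage sketch for the Decoder itself (again hypothetical, not part of the commit; it assumes a CUDA device, since pre_train_sphere samples its points on 'cuda'). Channel 0 of the output is the SDF, as pre_train_sphere shows; the remaining three channels are presumably used as per-vertex deformations in the accompanying DMTet tutorial:

import torch
from network import Decoder  # assumed import path, as above

model = Decoder().cuda()      # defaults: 3-D input, 4 outputs, 5 hidden layers, multires=2
model.pre_train_sphere(1000)  # initialize the SDF channel to a sphere of radius 0.3

p = torch.rand(1024, 3, device='cuda') - 0.5
pred = model(p)
sdf, deform = pred[..., 0], pred[..., 1:]
print(sdf.shape, deform.shape)  # torch.Size([1024]) torch.Size([1024, 3])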
