In [1]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import trimesh
from skimage import measure
import meshplot as mp
from torch.utils.data import DataLoader, Dataset
import os
from datetime import datetime

In [2]:
# tag shape index to each x=xyz, y=sdf
# horizontally concatenate for multiple chapes
# returns shape index, x=xyz (1 by 3), y=sdf
class ChairDataset(Dataset):
    def __init__(self, file_paths):
        for idx, file_path in enumerate(file_paths): #TODO: move to getitem for autoencoder
            training_set = np.load(file_path)
            points = training_set['points']
            points = torch.from_numpy(points.astype(np.float32)) 
            idx_tensor = torch.ones((points.shape[0], 1), dtype=torch.int) * (idx) # 0 based indexing shapes
            
            sdf = training_set['sdf']
            sdf = torch.from_numpy(sdf.astype(np.float32)) 
            sdf = sdf.view(sdf.shape[0], 1)
            
            n_samples = sdf.shape[0] 

            if idx == 0:
                self.shape_idx = idx_tensor
                self.x = points
                self.y = sdf
                self.n_samples = n_samples
            else:
                self.shape_idx = torch.cat((self.shape_idx, idx_tensor), dim = 0) # concatenate vertically
                self.x = torch.cat((self.x, points), dim = 0) #concatenate vertically
                self.y = torch.cat((self.y, sdf), dim = 0) #concatenate vertically
                self.n_samples += n_samples
            
    def __getitem__(self, index):
        return self.shape_idx[index], self.x[index], self.y[index]
    
    def __len__(self):
        return self.n_samples

In [3]:
# autodecoder MLP class
class MLP(nn.Module):
    def __init__(self, n_shapes, shape_code_length, n_inner_nodes):
        super(MLP, self).__init__()
        self.shape_code_length = shape_code_length
        self.shape_codes = nn.Embedding(n_shapes, shape_code_length, max_norm=0.01) # shape code as an embedding # TODO: take this outside 
        self.linear1 = nn.Linear(3 + shape_code_length, n_inner_nodes) # (x, y, z) + shape code 
        self.linear2 = nn.Linear(n_inner_nodes, n_inner_nodes)
        self.linear3 = nn.Linear(n_inner_nodes, n_inner_nodes)
        self.linear4 = nn.Linear(n_inner_nodes, 1) # output a SDF value
        self.relu = nn.ReLU()

    def forward(self, shape_idx, x):
        shape_code = self.shape_codes(shape_idx.view(1, -1))
        shape_code = shape_code.view(-1, self.shape_code_length)
        shape_code_with_xyz = torch.cat((x, shape_code), dim=1) # concatenate horizontally
        
        out = self.linear1(shape_code_with_xyz)
        out = self.relu(out)
        out = self.linear2(out)
        out = self.relu(out)
        out = self.linear3(out)
        out = self.relu(out)
        out = self.linear4(out)
        
        return out

In [4]:
# loading files
n_files = 10 
file_paths = []
main_dir = '../data/03001627_sdfs/'
for sub_dir in os.scandir(main_dir):
    if sub_dir.is_dir():
        for file in os.listdir(main_dir + sub_dir.name):
            file_paths.append(main_dir + sub_dir.name + '/' + file) if file.endswith("sdf_samples.npz") else None
    if len(file_paths) == n_files:
        break

print(file_paths)

['../data/03001627_sdfs/4e664dae1bafe49f19fb4103277a6b93/sdf_samples.npz', '../data/03001627_sdfs/65840c85162994d990de7d30a74bbb6b/sdf_samples.npz', '../data/03001627_sdfs/c7ae4cc12a7bc2581fa16f9a5527bb27/sdf_samples.npz', '../data/03001627_sdfs/e9e224bc0a0787d8320f10afdfbaa18/sdf_samples.npz', '../data/03001627_sdfs/68c7f82dd1e1634d9338458f802f5ad7/sdf_samples.npz', '../data/03001627_sdfs/58a7b826ed562b7bb0957d845ac33749/sdf_samples.npz', '../data/03001627_sdfs/e4c866b5dd958cd0803d0f5bac2abe4c/sdf_samples.npz', '../data/03001627_sdfs/9a91a491a9e74ab132c074e5313866f2/sdf_samples.npz', '../data/03001627_sdfs/670b6b7d3fe6e4a77c5a5393686fdcfc/sdf_samples.npz', '../data/03001627_sdfs/2d701c588b3bbbc458c88d30f502a452/sdf_samples.npz']


In [5]:
# training main loop
dataset = ChairDataset(file_paths)
dataloader = DataLoader(dataset=dataset, batch_size=2048, shuffle=True) #TODO: can specify num_workers

n_epochs = 1000
learning_rate = 0.001
# hidden_size = 500 
# hidden_layer_act_func = "ReLU"
# output_layer_act_func = None

model = MLP(len(file_paths), 256, 256)

criterion = nn.MSELoss()

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

model.train(True) # prep model for training, needed for modules like Dropout

now = datetime.now() 
now = now.strftime("%m%d%Y_%H%M%S")
print(f'datetime now: {now}')

for epoch in range(n_epochs):
    train_loss = 0.0 # monitor training loss

    for shape_idx, xyz, sdf in dataloader:
        optimizer.zero_grad() 
        sdf_pred = model(shape_idx, xyz) 
        loss = criterion(torch.clamp(sdf_pred, -0.1, 0.1), torch.clamp(sdf, -0.1, 0.1)) 
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()*shape_idx.size(0) # update running training loss
        
    train_loss = train_loss/len(dataloader.dataset) # calculate average loss over one epoch
    if epoch%10 == 0:
        print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch+1, train_loss))
        torch.save(model.state_dict(), f'./models/autodecoder_' + now)

In [None]:
filename = './models/multiple shapes_08052022_073446'
model = MLP(len(file_paths), 256, 256)
model.load_state_dict(torch.load(filename))
model.eval()

x = np.linspace(-1, 1, 200, dtype=np.float32)
y = np.linspace(-1, 1, 200, dtype=np.float32)
z = np.linspace(-1, 1, 200, dtype=np.float32)
P = np.vstack(np.meshgrid(x,y,z)).reshape(3,-1).T  # format: [[x1, y1, z1], [x1, y1, z2], [] ...]
P = torch.from_numpy(P)

shape_idx = 0
shape_idx_tensor = torch.ones((P.shape[0], 1), dtype=torch.int) * shape_idx

volume = model(shape_idx_tensor,P).view(len(x), len(y), len(z)) 
volume = volume.detach().numpy()

verts, faces, normals, values = measure.marching_cubes(volume, 0)
mp.plot(verts, faces)

In [7]:
from functools import partial
from multiprocessing.pool import ThreadPool
import multiprocessing
import itertools
import time

filename = './models/autodecoder_08052022_073446'
model = MLP(len(file_paths), 256, 256)
model.load_state_dict(torch.load(filename))
model.eval()
shape_idx = 0

WORKERS = multiprocessing.cpu_count()

def _cartesian_product(*arrays):
    la = len(arrays)
    dtype = np.result_type(*arrays)
    arr = np.empty([len(a) for a in arrays] + [la], dtype=dtype)
    for i, a in enumerate(np.ix_(*arrays)):
        arr[...,i] = a
    return arr.reshape(-1, la)

def _worker(job):
    X, Y, Z = job
    # P = np.vstack(np.meshgrid(X,Y,Z)).reshape(3,-1).T
    P = _cartesian_product(X, Y, Z)
    P = torch.from_numpy(P)

    shape_idx_tensor = torch.ones((P.shape[0], 1), dtype=torch.int) * shape_idx
    volume = model(shape_idx_tensor, P).reshape((len(X), len(Y), len(Z)))
    volume = volume.detach().numpy()

    verts, faces, normals, values = measure.marching_cubes(volume, 0)
    # verts = verts[faces]
    scale = np.array([X[1] - X[0], Y[1] - Y[0], Z[1] - Z[0]])
    offset = np.array([X[0], Y[0], Z[0]])
    # verts = verts * scale + offset
    # verts = verts[faces]
    return verts*scale+offset, faces

X = np.linspace(-1, 1, 201)
Y = np.linspace(-1, 1, 201)
Z = np.linspace(-1, 1, 201)

batch_size = 101
s = batch_size
Xs = [X[i:i+s+1] for i in range(0, len(X), s)]
Ys = [Y[i:i+s+1] for i in range(0, len(Y), s)]
Zs = [Z[i:i+s+1] for i in range(0, len(Z), s)]

batches = list(itertools.product(Xs, Ys, Zs))
num_batches = len(batches)
num_samples = sum(len(xs) * len(ys) * len(zs) for xs, ys, zs in batches)
print(num_batches, num_samples)

pool = ThreadPool(WORKERS)
f = partial(_worker) # _worker, sdf callables

verts_combined = []
faces_combined = []

for verts, faces in pool.imap(f, batches):
    faces_combined.extend(np.array(faces) + len(verts_combined))
    verts_combined.extend(np.array(verts))
    
    
mp.plot(np.array(verts_combined), np.array(faces_combined))

8 8242408


TypeError: expected Tensor as element 0 in argument 0, but got numpy.ndarray

In [7]:
# compare with ground truth
shape_idx = 0
mesh = trimesh.load(file_paths[shape_idx].split('sdf_samples.npz')[0] + 'mesh.obj')
mp.plot(mesh.vertices, mesh.faces, c=np.array([0, 0.9, 0.9]))



Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0, 0.0,…

<meshplot.Viewer.Viewer at 0x7f16d785e880>

In [None]:
# filename = f'./models/multiple shapes_' + now
# loaded_model = MLP(len(file_paths), 256, 256)
# loaded_model.load_state_dict(torch.load(filename))
# loaded_model.eval()

x = np.linspace(-1, 1, 150, dtype=np.float32)
y = np.linspace(-1, 1, 150, dtype=np.float32)
z = np.linspace(-1, 1, 150, dtype=np.float32)
P = np.vstack(np.meshgrid(x,y,z)).reshape(3,-1).T  # format: [[x1, y1, z1], [x1, y1, z2], [] ...]
P = torch.from_numpy(P)

shape_idx = 1
shape_idx_tensor = torch.ones((P.shape[0], 1), dtype=torch.int) * (shape_idx)

volume = model(shape_idx_tensor,P).view(len(x), len(y), len(z)) 
volume = volume.detach().numpy()

verts, faces, normals, values = measure.marching_cubes(volume, 0)
mp.plot(verts, faces)

In [None]:
# compare with ground truth
shape_idx = 1
mesh = trimesh.load(file_paths[shape_idx].split('sdf_samples.npz')[0] + 'mesh.obj')
mp.plot(mesh.vertices, mesh.faces, c=np.array([0, 0.9, 0.9]))

In [None]:
# filename = f'./models/multiple shapes_' + now
# loaded_model = MLP(len(file_paths), 256, 256)
# loaded_model.load_state_dict(torch.load(filename))
# loaded_model.eval()

x = np.linspace(-1, 1, 150, dtype=np.float32)
y = np.linspace(-1, 1, 150, dtype=np.float32)
z = np.linspace(-1, 1, 150, dtype=np.float32)
P = np.vstack(np.meshgrid(x,y,z)).reshape(3,-1).T  # format: [[x1, y1, z1], [x1, y1, z2], [] ...]
P = torch.from_numpy(P)

shape_idx = 2
shape_idx_tensor = torch.ones((P.shape[0], 1), dtype=torch.int) * (shape_idx)

volume = model(shape_idx_tensor,P).view(len(x), len(y), len(z)) 
volume = volume.detach().numpy()

verts, faces, normals, values = measure.marching_cubes(volume, 0)
mp.plot(verts, faces)

In [None]:
# compare with ground truth
shape_idx = 2
mesh = trimesh.load(file_paths[shape_idx].split('sdf_samples.npz')[0] + 'mesh.obj')
mp.plot(mesh.vertices, mesh.faces, c=np.array([0, 0.9, 0.9]))

In [None]:
# filename = f'./models/multiple shapes_' + now
# loaded_model = MLP(len(file_paths), 256, 256)
# loaded_model.load_state_dict(torch.load(filename))
# loaded_model.eval()

x = np.linspace(-1, 1, 150, dtype=np.float32)
y = np.linspace(-1, 1, 150, dtype=np.float32)
z = np.linspace(-1, 1, 150, dtype=np.float32)
P = np.vstack(np.meshgrid(x,y,z)).reshape(3,-1).T  # format: [[x1, y1, z1], [x1, y1, z2], [] ...]
P = torch.from_numpy(P)

shape_idx = 5
shape_idx_tensor = torch.ones((P.shape[0], 1), dtype=torch.int) * (shape_idx)

volume = model(shape_idx_tensor,P).view(len(x), len(y), len(z)) 
volume = volume.detach().numpy()

verts, faces, normals, values = measure.marching_cubes(volume, 0)
mp.plot(verts, faces)

In [None]:
# compare with ground truth
shape_idx = 5
mesh = trimesh.load(file_paths[shape_idx].split('sdf_samples.npz')[0] + 'mesh.obj')
mp.plot(mesh.vertices, mesh.faces, c=np.array([0, 0.9, 0.9]))

In [None]:
# filename = f'./models/multiple shapes_' + now
# loaded_model = MLP(len(file_paths), 256, 256)
# loaded_model.load_state_dict(torch.load(filename))
# loaded_model.eval()

x = np.linspace(-1, 1, 150, dtype=np.float32)
y = np.linspace(-1, 1, 150, dtype=np.float32)
z = np.linspace(-1, 1, 150, dtype=np.float32)
P = np.vstack(np.meshgrid(x,y,z)).reshape(3,-1).T  # format: [[x1, y1, z1], [x1, y1, z2], [] ...]
P = torch.from_numpy(P)

shape_idx = 6
shape_idx_tensor = torch.ones((P.shape[0], 1), dtype=torch.int) * (shape_idx)

volume = model(shape_idx_tensor,P).view(len(x), len(y), len(z)) 
volume = volume.detach().numpy()

verts, faces, normals, values = measure.marching_cubes(volume, 0)
mp.plot(verts, faces)

In [None]:
# compare with ground truth
shape_idx = 6
mesh = trimesh.load(file_paths[shape_idx].split('sdf_samples.npz')[0] + 'mesh.obj')
mp.plot(mesh.vertices, mesh.faces, c=np.array([0, 0.9, 0.9]))

In [None]:
# filename = f'./models/multiple shapes_' + now
# loaded_model = MLP(len(file_paths), 256, 256)
# loaded_model.load_state_dict(torch.load(filename))
# loaded_model.eval()

x = np.linspace(-1, 1, 150, dtype=np.float32)
y = np.linspace(-1, 1, 150, dtype=np.float32)
z = np.linspace(-1, 1, 150, dtype=np.float32)
P = np.vstack(np.meshgrid(x,y,z)).reshape(3,-1).T  # format: [[x1, y1, z1], [x1, y1, z2], [] ...]
P = torch.from_numpy(P)

shape_idx = 7
shape_idx_tensor = torch.ones((P.shape[0], 1), dtype=torch.int) * (shape_idx)

volume = model(shape_idx_tensor,P).view(len(x), len(y), len(z)) 
volume = volume.detach().numpy()

verts, faces, normals, values = measure.marching_cubes(volume, 0)
mp.plot(verts, faces)

In [None]:
# compare with ground truth
shape_idx = 7
mesh = trimesh.load(file_paths[shape_idx].split('sdf_samples.npz')[0] + 'mesh.obj')
mp.plot(mesh.vertices, mesh.faces, c=np.array([0, 0.9, 0.9]))

In [None]:
# filename = f'./models/multiple shapes_' + now
# loaded_model = MLP(len(file_paths), 256, 256)
# loaded_model.load_state_dict(torch.load(filename))
# loaded_model.eval()

x = np.linspace(-1, 1, 300, dtype=np.float32)
y = np.linspace(-1, 1, 300, dtype=np.float32)
z = np.linspace(-1, 1, 300, dtype=np.float32)
P = np.vstack(np.meshgrid(x,y,z)).reshape(3,-1).T  # format: [[x1, y1, z1], [x1, y1, z2], [] ...]
P = torch.from_numpy(P)

shape_idx = 9
shape_idx_tensor = torch.ones((P.shape[0], 1), dtype=torch.int) * (shape_idx)

volume = model(shape_idx_tensor,P).view(len(x), len(y), len(z)) 
volume = volume.detach().numpy()

verts, faces, normals, values = measure.marching_cubes(volume, 0)
mp.plot(verts, faces)

In [None]:
# compare with ground truth
shape_idx = 9
mesh = trimesh.load(file_paths[shape_idx].split('sdf_samples.npz')[0] + 'mesh.obj')
mp.plot(mesh.vertices, mesh.faces, c=np.array([0, 0.9, 0.9]))