Read the paper: Occupancy Networks: Learning 3D Reconstruction in Function Space (Mescheder et al., CVPR 2019) — https://arxiv.org/abs/1812.03828

In [None]:
!pip install livelossplot --quiet

import os
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from livelossplot import PlotLosses
from skimage.measure import marching_cubes_lewiner as marching_cubes    # Lewiner et al. algorithm is faster, resolves ambiguities, and guarantees topologically correct results

%matplotlib inline
import matplotlib.pyplot as plt

# set random seeds for reproducibility: numpy drives the data shuffling / point
# sub-sampling, torch drives weight initialization and DataLoader shuffling
np.random.seed(42)
torch.manual_seed(42)

data_dir = 'data'
out_dir = 'output'

# create the data and output directories if they do not exist yet
for d in [data_dir, out_dir]:
  os.makedirs(d, exist_ok=True)

In [None]:
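# Download points.npz manually from https://1drv.ms/u/s!AhnVhbVlzYkKgQeFmSuewkQcEJy_?e=ycVHcx
# and place it in the working directory (it is loaded as './points.npz' below).
# OneDrive share links do not work reliably with !wget, so download it via the browser.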

In [50]:
def load_data(file_path):
  ''' Load points and occupancy values from an .npz file.

  Args:
  file_path (string): path to an .npz file with the arrays 'points' and
    'occupancies' (the latter bit-packed, i.e. one bit per point)

  Returns:
  points (np.ndarray): (N, 3) point coordinates, axes reordered so that z points to the top of the object
  occupancies (np.ndarray): (N,) float32 occupancy labels, each 0. or 1.
  '''
  # np.load returns a lazily-reading NpzFile that keeps the file open; the
  # context manager closes it once the arrays have been read
  with np.load(file_path) as data_dict:
    points = data_dict['points']
    occupancies = data_dict['occupancies']

  # Unpack data format of occupancies (packed bits -> one 0/1 value per point;
  # the packed array may carry padding bits, so truncate to the number of points)
  occupancies = np.unpackbits(occupancies)[:points.shape[0]]
  occupancies = occupancies.astype(np.float32)

  # Align z-axis with top of object: (x, y, z) -> (x, -z, y)
  points = np.stack([points[:, 0], -points[:, 2], points[:, 1]], 1)

  return points, occupancies

In [51]:
def visualize_occupancy(points, occupancies, n=50000):
  ''' Visualize points and occupancy values as a 3D scatter plot.

  Occupied points (occupancy == 1) are drawn in red, all others in faint blue.

  Args:
  points (torch.Tensor or np.ndarray): 3D coordinates of the points, shape (N, 3)
  occupancies (torch.Tensor or np.ndarray): occupancy values for the points, N values
  n (int): maximum number of points to visualize
  '''
  # if needed convert torch.tensor to np.ndarray
  if isinstance(points, torch.Tensor):
    points = points.detach().cpu().numpy()
  if isinstance(occupancies, torch.Tensor):
    occupancies = occupancies.detach().cpu().numpy()

  fig = plt.figure()
  ax = fig.add_subplot(projection='3d')

  n = min(len(points), n)

  # visualize a random subset of n points: sample indices WITHOUT replacement so
  # no point is drawn twice and, for n == len(points), every point is shown
  idcs = np.random.choice(len(points), n, replace=False)
  points = points[idcs]
  occupancies = np.asarray(occupancies)[idcs].reshape(n, 1)

  # define RGBA colors
  red = np.array([1, 0, 0, 0.5])      # occupied points: red with alpha=0.5
  blue = np.array([0, 0, 1, 0.01])    # free points: blue with alpha=0.01

  # (n, 1) mask broadcast against (4,) colors -> (n, 4) RGBA array
  color = np.where(occupancies == 1, red, blue)

  # plot the points
  ax.scatter(*points.transpose(), color=color)

  # make it pretty
  ax.set_xlabel('X')
  ax.set_ylabel('Y')
  ax.set_zlabel('Z')

  ax.set_xlim(-0.5, 0.5)
  ax.set_ylim(-0.5, 0.5)
  ax.set_zlim(-0.5, 0.5)

  plt.show()

In [None]:
# Load the sample shape and inspect its ground-truth occupancy field.
points, occupancies = load_data('./points.npz')
visualize_occupancy(points=points, occupancies=occupancies)

In [53]:
def get_train_val_split(points, occupancies, train_fraction=0.8):
  ''' Randomly split data into a training and a validation set.

  Args:
  points (torch.Tensor or np.ndarray): 3D coordinates of the points, shape (N, 3)
  occupancies (torch.Tensor or np.ndarray): occupancy values for the points, N values
  train_fraction (float): fraction of the points assigned to the training set

  Returns:
  train_set, val_set (torch.utils.data.TensorDataset): datasets yielding
    (point, occupancy) pairs as float32 tensors
  '''
  points = np.asarray(points)
  occupancies = np.asarray(occupancies)

  n_train = int(train_fraction * len(points))

  # put points and labels side by side so one shuffle permutes both consistently;
  # cast to float32 so the tensors match the dtype of the network's weights
  data = np.concatenate([points, occupancies.reshape(-1, 1)], 1).astype(np.float32)
  np.random.shuffle(data)     # randomly shuffles data along first axis

  train_data, val_data = np.split(data, [n_train])
  train_points, train_occs = train_data[:, :3], train_data[:, 3]
  val_points, val_occs = val_data[:, :3], val_data[:, 3]

  # this converts the points and labels from numpy.ndarray to a pytorch dataset
  train_set = torch.utils.data.TensorDataset(torch.from_numpy(train_points), torch.from_numpy(train_occs))
  val_set = torch.utils.data.TensorDataset(torch.from_numpy(val_points), torch.from_numpy(val_occs))
  return train_set, val_set

# random train/validation split of the loaded shape's points
train_set, val_set = get_train_val_split(points=points, occupancies=occupancies)

In [54]:
batch_size = 64

# Both loaders share batch size and worker count; they differ only in ordering and
# last-batch handling: training data is reshuffled every epoch and an incomplete
# final batch is dropped, while validation keeps a fixed order and every sample.
_loader_kwargs = dict(batch_size=batch_size, num_workers=1)
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, drop_last=True, **_loader_kwargs)
val_loader = torch.utils.data.DataLoader(val_set, shuffle=False, drop_last=False, **_loader_kwargs)

In [None]:
# Sanity check: visualize the first batches (at least 10000 points, or all of them
# if the split is smaller) that each data loader serves, training split first.
for loader in [train_loader, val_loader]:
  check_points, check_occs = [], []
  n_seen = 0

  for pts, occs in loader:            # iterate the loader of the current split
    check_points.append(pts)
    check_occs.append(occs)
    n_seen += len(pts)
    if n_seen >= 10000:               # only visualize some points
      break

  check_points, check_occs = torch.cat(check_points), torch.cat(check_occs)
  visualize_occupancy(check_points, check_occs)

In [56]:
class OccNet(nn.Module):
  """ Fully connected network that maps a 3D location to a single occupancy score.

  Architecture: Linear(3, size_h) -> ReLU, then `n_hidden` blocks of
  Linear(size_h, size_h) -> ReLU, then Linear(size_h, 1). The output is the
  unnormalized score (logit) with the trailing singleton dimension removed.

  Args:
  size_h (int): hidden dimension
  n_hidden (int): number of hidden layers
  """
  def __init__(self, size_h=64, n_hidden=4):
    super().__init__()
    relu = nn.ReLU()    # stateless, so one instance can be shared by all layers

    # arguments are evaluated left to right, so the Linear layers are created
    # (and randomly initialized) input-to-output
    self.main = nn.Sequential(
      nn.Linear(3, size_h), relu,
      *[module for _ in range(n_hidden) for module in (nn.Linear(size_h, size_h), relu)],
      nn.Linear(size_h, 1),
    )

  def forward(self, pts):
    logits = self.main(pts)
    return logits.squeeze(-1)     # drop the size-1 output dimension

model = OccNet(size_h=64, n_hidden=4)

# Training is much faster on a GPU: move the model there when one exists, otherwise warn.
if not torch.cuda.is_available():
  print('Fall back to CPU - GPU usage is recommended, e.g. using Google Colab.')
else:
  model = model.cuda()

In [57]:
# BCEWithLogitsLoss = sigmoid on the raw network scores followed by binary cross
# entropy, computed in one numerically stable step (not a softmax/softargmax).
# reduction='none' returns one loss value per point; averaging is done by the caller.
criterion = nn.BCEWithLogitsLoss(reduction='none')
optimizer = torch.optim.Adam(model.parameters())    # Adam with PyTorch's default hyperparameters

In [None]:
def _evaluate(model, val_loader, criterion):
  ''' Compute the validation loss averaged over all validation points.

  Args:
  model (nn.Module): network to evaluate
  val_loader (torch.utils.data.DataLoader): yields (points, occupancies) batches
  criterion: loss with reduction='none', i.e. returning one value per point

  Returns:
  float: mean per-point loss (nan if the loader yields no data)
  '''
  was_training = model.training
  model.eval()
  loss_sum, n_points = 0., 0
  with torch.no_grad():
    for val_pts, val_occ in val_loader:
      # put data on GPU
      if torch.cuda.is_available():
        val_pts, val_occ = val_pts.cuda(), val_occ.cuda()

      val_loss_i = criterion(model(val_pts), val_occ)
      # accumulate running sums instead of keeping every per-point loss tensor
      loss_sum += val_loss_i.sum().item()
      n_points += val_loss_i.numel()
  model.train(was_training)     # restore the mode the model was in before
  return loss_sum / n_points if n_points > 0 else float('nan')


def train_model(model, train_loader, val_loader, optimizer, criterion, nepochs=15, eval_every=100, out_dir='output'):
  ''' Train `model` and checkpoint the weights with the lowest validation loss.

  The validation loss is computed at the first iteration and then every
  `eval_every` iterations; whenever it improves, the weights are saved to
  `<out_dir>/model_best.pt`. Once per epoch, the mean training loss and the most
  recent validation loss are sent to a live plot.

  Args:
  model (nn.Module): network to train
  train_loader (torch.utils.data.DataLoader): training (points, occupancies) batches
  val_loader (torch.utils.data.DataLoader): validation (points, occupancies) batches
  optimizer (torch.optim.Optimizer): optimizer over the model parameters
  criterion: loss with reduction='none' (one value per point)
  nepochs (int): number of passes over the training data
  eval_every (int): validation interval in iterations
  out_dir (string): directory the best checkpoint is written to

  Returns:
  float: best validation loss observed (inf if no evaluation improved on it)
  '''
  liveloss = PlotLosses()   # to plot training progress

  best = float('inf')
  val_loss = float('nan')   # defined up front so logging never hits an unbound name
  it = 0
  for epoch in range(nepochs):
    model.train()
    train_losses = []       # training losses of the current epoch
    for pts, occ in train_loader:
      it += 1

      # put data on GPU
      if torch.cuda.is_available():
        pts, occ = pts.cuda(), occ.cuda()

      optimizer.zero_grad()

      scores = model(pts)
      loss = criterion(scores, occ).mean()    # average the per-point losses of the batch

      loss.backward()
      optimizer.step()

      train_losses.append(loss.item())

      if (it == 1) or (it % eval_every == 0):
        val_loss = _evaluate(model, val_loader, criterion)

        if val_loss < best:     # keep track of best model
          best = val_loss
          torch.save(model.state_dict(), os.path.join(out_dir, 'model_best.pt'))

    # update liveplot with latest values
    liveloss.update({
      'loss': np.mean(train_losses) if train_losses else float('nan'),   # average over the epoch
      'val_loss': val_loss,
    })
    liveloss.send()

  return best

# 25 epochs, validating (and checkpointing the best weights) every 100 iterations
train_model(model=model, train_loader=train_loader, val_loader=val_loader,
            optimizer=optimizer, criterion=criterion,
            nepochs=25, eval_every=100, out_dir=out_dir)


In [None]:
def make_grid(xmin, xmax, resolution):
  """ Create equidistant points on a 3D grid (cube shaped).

  Args:
  xmin (float): minimum for x, y, z
  xmax (float): maximum for x, y, z
  resolution (int): number of grid points along each axis (R)

  Returns:
  torch.Tensor: (R^3, 3) grid coordinates. The z index varies fastest, then y,
    then x, so per-point values reshaped to (R, R, R) are indexed as [x, y, z].
  """
  grid_1d = torch.linspace(xmin, xmax, resolution)
  # cartesian product == flattened 'ij'-indexed meshgrid, without the meshgrid indexing warning
  return torch.cartesian_prod(grid_1d, grid_1d, grid_1d)

resolution = 128          # use 128 grid points in each of the three dimensions -> 128^3 query points
grid = make_grid(-0.5, 0.5, resolution)

# wrap query points in a data loader; the order must be preserved (shuffle=False)
# because the predictions are later reshaped back into the RxRxR grid
batch_size = 128
test_loader = torch.utils.data.DataLoader(
    grid, batch_size=batch_size, num_workers=1, shuffle=False, drop_last=False
)


In [None]:
# Restore the best checkpoint saved in the training loop. map_location='cpu' makes
# loading independent of the device it was saved from; load_state_dict then copies
# the weights onto whatever device the model is on.
weights_best = torch.load(os.path.join(out_dir, 'model_best.pt'), map_location='cpu')
model.load_state_dict(weights_best)
model.eval()

grid_values = []
with torch.no_grad():
  for pts in tqdm(test_loader, desc='Evaluate occupancy of grid points', position=0, leave=True):
    if torch.cuda.is_available():
      pts = pts.cuda()
    # keep whole batches: splitting into one tensor per point is very slow for R^3 points
    grid_values.append(model(pts).cpu())

grid_values = torch.cat(grid_values)    # (R^3,) scores, in the same order as `grid`

In [None]:
# A score above 0 means sigmoid(score) > 0.5, i.e. the point is predicted as occupied.
grid_occupancies = grid_values.gt(0.)
visualize_occupancy(grid, grid_occupancies)

In [None]:
# extract mesh with Marching Cubes
threshold = 0.  # grid values are model scores (logits): 0 corresponds to probability 0.5
# explicit check instead of `assert`, which is skipped when Python runs with -O
if grid_values.min() > threshold or grid_values.max() < threshold:
  raise ValueError("Threshold is not in range of predicted values")

# derive the grid geometry from the query points so the mesh ends up in the same
# coordinate frame as the points fed to the network
origin = grid[0].numpy()                         # corner (xmin, xmin, xmin) of the query grid
extent = (grid[-1] - grid[0]).numpy()            # edge length of the grid along x, y, z
spacing = tuple(float(e) / (resolution - 1) for e in extent)

vertices, faces, _, _ = marching_cubes(grid_values.reshape(resolution, resolution, resolution).numpy(),
                                       threshold,
                                       spacing=spacing,
                                       allow_degenerate=False)
vertices = vertices + origin    # marching cubes places grid index (0, 0, 0) at the coordinate origin

# plot mesh
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot_trisurf(vertices[:, 0], vertices[:, 1], triangles=faces, Z=vertices[:, 2])

ax.set_xlim(origin[0], origin[0] + extent[0])
ax.set_ylim(origin[1], origin[1] + extent[1])
ax.set_zlim(origin[2], origin[2] + extent[2])

plt.show()