<a href="https://colab.research.google.com/github/JTStephens18/Neural-Fields-As-Learnable-Kernels-Paper-Implementation/blob/main/Neural_Fields_as_Learnable_Kernels_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Imports

In [None]:
import os
import random
import numpy as np
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from scipy.interpolate import interpn

In [None]:
# Show the installed PyTorch version; the same string is interpolated into
# the wheel index URL on the next line.
print(torch.__version__)
# Install torch-scatter / torch-sparse from the PyG wheel index for this
# torch version, then torchinfo from the default index.
!pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html
! pip install torchinfo

In [None]:
# Install PyTorch3D only when it cannot be imported.
# `torch` is used here without a local import; it is imported in the first
# cell of the notebook.
# import os
import sys
# import torch
need_pytorch3d=False
try:
    import pytorch3d
except ModuleNotFoundError:
    need_pytorch3d=True
if need_pytorch3d:
    if torch.__version__.startswith("2.1.") and sys.platform.startswith("linux"):
        # We try to install PyTorch3D via a released wheel.
        # The wheel directory name is built from the Python minor version,
        # the CUDA version and the torch version with the dots removed.
        pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
        version_str="".join([
            f"py3{sys.version_info.minor}_cu",
            torch.version.cuda.replace(".",""),
            f"_pyt{pyt_version_str}"
        ])
        !pip install fvcore iopath
        !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
    else:
        # We try to install PyTorch3D from source.
        !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'

In [None]:
from torch_scatter import scatter_mean, scatter_max
from torchinfo import summary
from pytorch3d.ops import sample_farthest_points
from pytorch3d.loss import chamfer_distance

In [None]:
# Unpack three point archives from /save/path into the working directory.
# NOTE(review): presumably these were produced by the preprocessing cells
# further down; the archive names here do not all match the names those cells
# write (e.g. occupancyPoints_1.zip) - confirm.
!unzip "/save/path/augmentedSurfacePoints.zip"
!unzip "/save/path/occupancyPoints_1.zip"
!unzip "/save/path/sampledSurfacePoints.zip"

## Preprocessing

In [None]:
#  Copyright (C) 2012 Daniel Maturana
#  This file is part of binvox-rw-py.
#
#  binvox-rw-py is free software: you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  binvox-rw-py is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License
#  along with binvox-rw-py. If not, see <http://www.gnu.org/licenses/>.
#

"""
Binvox to Numpy and back.


>>> import numpy as np
>>> import binvox_rw
>>> with open('chair.binvox', 'rb') as f:
...     m1 = binvox_rw.read_as_3d_array(f)
...
>>> m1.dims
[32, 32, 32]
>>> m1.scale
41.133000000000003
>>> m1.translate
[0.0, 0.0, 0.0]
>>> with open('chair_out.binvox', 'wb') as f:
...     m1.write(f)
...
>>> with open('chair_out.binvox', 'rb') as f:
...     m2 = binvox_rw.read_as_3d_array(f)
...
>>> m1.dims==m2.dims
True
>>> m1.scale==m2.scale
True
>>> m1.translate==m2.translate
True
>>> np.all(m1.data==m2.data)
True

>>> with open('chair.binvox', 'rb') as f:
...     md = binvox_rw.read_as_3d_array(f)
...
>>> with open('chair.binvox', 'rb') as f:
...     ms = binvox_rw.read_as_coord_array(f)
...
>>> data_ds = binvox_rw.dense_to_sparse(md.data)
>>> data_sd = binvox_rw.sparse_to_dense(ms.data, 32)
>>> np.all(data_sd==md.data)
True
>>> # the ordering of elements returned by numpy.nonzero changes with axis
>>> # ordering, so to compare for equality we first lexically sort the voxels.
>>> np.all(ms.data[:, np.lexsort(ms.data)] == data_ds[:, np.lexsort(data_ds)])
True
"""

import numpy as np

class Voxels(object):
    """A binvox model together with its metadata.

    ``data`` holds the voxels either densely (3-D boolean numpy array) or as
    coordinates (2-D numpy array with one column per occupied voxel).

    ``dims`` is the grid size, e.g. ``[32, 32, 32]``; ``scale`` and
    ``translate`` map voxel indices back to the original model frame, and
    ``axis_order`` records the index layout (``'xyz'`` or ``'xzy'``).

    Voxel index (i, j, k) corresponds to model coordinates:

        x = scale * (i + .5) / dims[0] + translate[0]
        y = scale * (j + .5) / dims[1] + translate[1]
        z = scale * (k + .5) / dims[2] + translate[2]
    """

    def __init__(self, data, dims, translate, scale, axis_order):
        # Only the two layouts produced by the readers are accepted.
        assert (axis_order in ('xzy', 'xyz'))
        self.axis_order = axis_order
        self.scale = scale
        self.translate = translate
        self.dims = dims
        self.data = data

    def clone(self):
        """Return a copy whose data, dims and translate are independent."""
        return Voxels(
            self.data.copy(),
            self.dims[:],
            self.translate[:],
            self.scale,
            self.axis_order,
        )

    def write(self, fp):
        """Serialise this model to ``fp`` via the module-level ``write``."""
        write(self, fp)

def read_header(fp):
    """Parse the text header of a binvox stream opened in binary mode.

    Consumes the magic line, the ``dim``, ``translate`` and ``scale`` lines
    and one further line (the ``data`` marker), leaving ``fp`` positioned at
    the start of the payload.

    Returns:
        (dims, translate, scale): list of ints, list of floats, float.

    Raises:
        IOError: if the first line does not start with ``#binvox``.
    """
    magic = fp.readline().strip()
    if not magic.startswith(b'#binvox'):
        raise IOError('Not a binvox file')

    def next_fields():
        # Every header line is "<keyword> <value> ..."; drop the keyword.
        return fp.readline().strip().split(b' ')[1:]

    dims = [int(token) for token in next_fields()]
    translate = [float(token) for token in next_fields()]
    scale = [float(token) for token in next_fields()][0]
    fp.readline()  # skip the line that precedes the payload
    return dims, translate, scale

def read_as_3d_array(fp, fix_coords=True):
    """Read a binary binvox stream into a dense boolean array.

    The payload is a sequence of (value, run length) byte pairs which is
    expanded and reshaped to ``dims``. A plain reshape indexes the array as
    [x, z, y]; with ``fix_coords`` the last two axes are swapped so the array
    is indexed [x, y, z].

    The dense array costs one byte per voxel, which can be large for big
    grids. No validation is done beyond the '#binvox' check in the header.

    Returns:
        Voxels with dense ``data`` and ``axis_order`` 'xyz' (fix_coords) or
        'xzy'.
    """
    dims, translate, scale = read_header(fp)
    rle = np.frombuffer(fp.read(), dtype=np.uint8)
    # Even bytes are voxel values, odd bytes are how often each one repeats.
    dense = np.repeat(rle[::2], rle[1::2]).astype(bool).reshape(dims)
    if not fix_coords:
        return Voxels(dense, dims, translate, scale, 'xzy')
    # xzy to xyz TODO the right thing
    return Voxels(np.transpose(dense, (0, 2, 1)), dims, translate, scale, 'xyz')

def read_as_coord_array(fp, fix_coords=True):
    """Read a binary binvox stream as voxel coordinates.

    Returns a Voxels object whose ``data`` is a 3 x N integer array, one
    column per occupied voxel. Rows are (x, y, z) when ``fix_coords`` is true
    and (x, z, y) otherwise (the order in which binvox lays the data out).
    Coordinates are raw voxel indices; no scale or translation is applied.

    Use this to save memory if your model is very sparse (mostly empty).

    Doesn't do any checks on input except for the '#binvox' line.
    """
    dims, translate, scale = read_header(fp)
    raw_data = np.frombuffer(fp.read(), dtype=np.uint8)

    values, counts = raw_data[::2], raw_data[1::2]

    # Start/end offset of every run in the flattened voxel stream. The
    # cumulative sum is forced to int64 so the uint8 counts cannot wrap.
    end_indices = np.cumsum(counts, dtype=np.int64)
    start_indices = end_indices - counts

    # Keep only the runs of occupied voxels and expand them to flat indices.
    filled = values.astype(bool)
    runs = [
        np.arange(start, stop, dtype=np.int64)
        for start, stop in zip(start_indices[filled], end_indices[filled])
    ]
    nz_voxels = np.concatenate(runs) if runs else np.empty(0, dtype=np.int64)

    # TODO are these dims correct?
    # according to docs,
    # index = x * wxh + z * width + y; // wxh = width * height = d * d
    #
    # Floor division is required here: under Python 3 a plain `/` is true
    # division and would yield fractional x and z coordinates.
    plane = dims[0] * dims[1]
    x = nz_voxels // plane
    zwpy = nz_voxels % plane  # z*w + y
    z = zwpy // dims[0]
    y = zwpy % dims[0]
    if fix_coords:
        data = np.vstack((x, y, z))
        axis_order = 'xyz'
    else:
        data = np.vstack((x, z, y))
        axis_order = 'xzy'

    return Voxels(np.ascontiguousarray(data), dims, translate, scale, axis_order)

def dense_to_sparse(voxel_data, dtype=int):
    """Convert a dense 3-D voxel array to a 3 x N coordinate array.

    The coordinates are the indices of the non-zero entries, in the axis
    order of the input (no reordering).

    Raises:
        ValueError: if ``voxel_data`` is not three-dimensional.
    """
    if voxel_data.ndim == 3:
        return np.asarray(np.nonzero(voxel_data), dtype)
    raise ValueError('voxel_data is wrong shape; should be 3D array.')

def sparse_to_dense(voxel_data, dims, dtype=bool):
    """Convert a 3 x N coordinate array to a dense voxel grid.

    Args:
        voxel_data: 3 x N array of voxel coordinates (truncated to integers).
        dims: grid size, either a scalar (cubic grid) or a length-3 sequence.
        dtype: dtype of the returned array.

    Returns:
        Array of shape ``dims`` with True written at every listed coordinate
        that lies inside the grid; coordinates outside it are dropped.

    Raises:
        ValueError: if ``voxel_data`` is not a 3 x N array.
    """
    if voxel_data.ndim != 2 or voxel_data.shape[0] != 3:
        raise ValueError('voxel_data is wrong shape; should be 3xN array.')
    if np.isscalar(dims):
        dims = [dims] * 3
    # Column vector so it broadcasts against the 3 x N coordinate array.
    dims = np.atleast_2d(dims).T
    # truncate to integers (the `np.int` alias was removed in NumPy 1.24;
    # the builtin `int` is the type it used to point at)
    xyz = voxel_data.astype(int)
    # discard voxels that fall outside dims
    valid_ix = ~np.any((xyz < 0) | (xyz >= dims), 0)
    xyz = xyz[:, valid_ix]
    out = np.zeros(dims.flatten(), dtype=dtype)
    out[tuple(xyz)] = True
    return out

#def get_linear_index(x, y, z, dims):
    #""" Assuming xzy order. (y increasing fastest.
    #TODO ensure this is right when dims are not all same
    #"""
    #return x*(dims[1]*dims[2]) + z*dims[1] + y

def write(voxel_model, fp):
    """Write a voxel model in binary binvox format.

    Args:
        voxel_model: object with ``data``, ``dims``, ``translate``, ``scale``
            and ``axis_order`` attributes (e.g. a ``Voxels`` instance). A
            model in sparse (coordinate) format is first converted to dense.
        fp: writable binary stream (e.g. a file opened with ``'wb'``).

    Raises:
        ValueError: if ``axis_order`` is neither 'xzy' nor 'xyz'. This is
            checked before anything is written.

    Doesn't check if the model is 'sane'.
    """
    if voxel_model.axis_order not in ('xzy', 'xyz'):
        raise ValueError('Unsupported voxel model axis order')

    if voxel_model.data.ndim == 2:
        # TODO avoid conversion to dense
        dense_voxel_data = sparse_to_dense(voxel_model.data, voxel_model.dims)
    else:
        dense_voxel_data = voxel_model.data

    # The header is plain text, but the stream is binary, so encode it.
    header = ''.join((
        '#binvox 1\n',
        'dim ' + ' '.join(map(str, voxel_model.dims)) + '\n',
        'translate ' + ' '.join(map(str, voxel_model.translate)) + '\n',
        'scale ' + str(voxel_model.scale) + '\n',
        'data\n',
    ))
    fp.write(header.encode('ascii'))

    # The file stores voxels in xzy order.
    if voxel_model.axis_order == 'xzy':
        voxels_flat = dense_voxel_data.flatten()
    else:
        voxels_flat = np.transpose(dense_voxel_data, (0, 2, 1)).flatten()
    voxels_flat = voxels_flat.astype(bool)

    # Run-length encode as (value, count) byte pairs; a count is one byte, so
    # runs longer than 255 are split into several pairs.
    encoded = bytearray()
    if voxels_flat.size:
        boundaries = np.flatnonzero(voxels_flat[1:] != voxels_flat[:-1]) + 1
        starts = np.concatenate(([0], boundaries))
        lengths = np.diff(np.concatenate((starts, [voxels_flat.size])))
        for value, length in zip(voxels_flat[starts].tolist(), lengths.tolist()):
            full_chunks, remainder = divmod(length, 255)
            encoded += bytes((int(value), 255)) * full_chunks
            if remainder:
                encoded += bytes((int(value), remainder))
    fp.write(bytes(encoded))

Sample points from the solid model and surface model

In [None]:
def sample_from_grid_space(point_cloud, grid_size=128, desired_sample_size=1024):
  """Sample voxel coordinates that are NOT occupied by ``point_cloud``.

  Args:
    point_cloud: tensor of voxel coordinates with a trailing dimension of 3,
      e.g. ``[num_points, 3]`` or a batched ``[1, num_points, 3]``; leading
      dimensions are flattened. Values are truncated to integers and must lie
      in ``[0, grid_size)``.
    grid_size: edge length of the cubic occupancy grid.
    desired_sample_size: maximum number of empty voxels to return.

  Returns:
    Long tensor of shape ``[M, 3]`` on the device of ``point_cloud`` holding
    ``M <= desired_sample_size`` empty voxel coordinates. They are chosen
    uniformly at random when more than ``desired_sample_size`` are free,
    otherwise all free voxels are returned.
  """
  # Flatten so that [N, 3] and batched [B, N, 3] inputs are handled alike;
  # indexing a batched tensor column-wise would silently mark wrong cells.
  voxel_indices = point_cloud.reshape(-1, 3).long()

  # Occupancy grid lives on the same device as the indices used to fill it.
  grid = torch.zeros(grid_size, grid_size, grid_size,
                     dtype=torch.bool, device=voxel_indices.device)
  grid[voxel_indices[:, 0], voxel_indices[:, 1], voxel_indices[:, 2]] = True

  # Coordinates of every cell that stayed empty, shape [num_empty, 3].
  outside_points = torch.nonzero(~grid)

  if outside_points.shape[0] > desired_sample_size:
    keep = torch.randperm(outside_points.shape[0],
                          device=outside_points.device)[:desired_sample_size]
    return outside_points[keep]
  return outside_points

In [None]:

def normaldefinition_3D_real(void_data, k):
    """Estimate an oriented normal for every point of a point set.

    For each point a plane is fitted to its ``k`` nearest neighbours
    (``tangentplane_3D_real``) and the resulting normals are then oriented
    consistently (``normalconsistency_3D_real``).

    Args:
        void_data: tensor whose trailing dimension is 3, e.g. ``[N, 3]`` or
            ``[B, N, 3]``; leading dimensions are flattened into one point
            set before neighbours are searched.
        k: number of nearest neighbours (including the point itself) used
            for each plane fit.

    Returns:
        A pair holding the same ``[total_points, 6]`` tensor twice; each row
        is (nx, ny, nz, cx, cy, cz), i.e. the oriented normal followed by the
        neighbourhood centroid.
    """
    points = void_data.reshape(-1, 3)
    total_pts = points.size(0)

    # Only the k smallest distances per row are needed, so select them
    # directly instead of sorting every row completely.
    dist = torch.cdist(points, points)
    closest = dist.topk(min(k, total_pts), dim=1, largest=False).indices

    planes = torch.zeros((total_pts, 6))
    for i in range(total_pts):
        normal_vect, xmn, ymn, zmn, _ = tangentplane_3D_real(closest[i], points, k)
        planes[i, 0:3] = normal_vect
        planes[i, 3:6] = torch.tensor([xmn, ymn, zmn])

    # Already [total_pts, 6]; no reshape to a fixed point count is needed.
    planes_consist = normalconsistency_3D_real(planes)

    return planes_consist, planes_consist

def tangentplane_3D_real(closest_pt, ellipsoid_data, k):
    """Fit a plane to the ``k`` nearest neighbours of one point.

    Args:
        closest_pt: 1-D index tensor of neighbour ids, nearest first.
        ellipsoid_data: ``[num_points, 3]`` tensor of point coordinates.
        k: number of neighbours to use.

    Returns:
        (normal_vect, xmn, ymn, zmn, knn_pt_coord): the singular vector of the
        neighbourhood covariance with the smallest singular value, the three
        centroid components as 0-d tensors, and the ``[k, 3]`` neighbour
        coordinates. The sign of the normal is not fixed here.
    """
    knn_pt_id = closest_pt[:k]
    knn_pt_coord = ellipsoid_data[knn_pt_id]

    # Centroid computed as one tensor: it keeps the dtype and device of the
    # data, so no global `device` is needed.
    centroid = knn_pt_coord.mean(dim=0)
    xmn, ymn, zmn = centroid.unbind(0)

    c = knn_pt_coord - centroid

    cov = torch.mm(c.t(), c) / k
    u, s, _ = torch.linalg.svd(cov)
    # The direction of least variance is the plane normal.
    normal_vect = u[:, s.argmin()]

    return normal_vect, xmn, ymn, zmn, knn_pt_coord

def normalconsistency_3D_real(planes):
    """Orient plane normals consistently with respect to the origin.

    Args:
        planes: ``[n, 6]`` tensor; columns 0-2 are a normal, columns 3-5 the
            point the normal is attached to.

    Returns:
        ``[n, 6]`` tensor (same dtype and device as ``planes``) with the
        points copied unchanged. A normal is negated when the angle between
        it and the direction from its point to the origin is at most 90
        degrees; otherwise it is kept. A point lying exactly at the origin
        has no defined direction and its normal is kept as is.
    """
    normals = planes[:, 0:3]
    centres = planes[:, 3:6]

    sensorcentre = torch.zeros(3, dtype=planes.dtype, device=planes.device)

    # Unit vectors from each point towards the sensor centre.
    to_sensor = sensorcentre - centres
    p1 = to_sensor / torch.norm(to_sensor, dim=1, keepdim=True)

    # Angle between p1 and the normal, in [0, pi]; computed for all rows at
    # once instead of looping over them.
    cross_norm = torch.norm(torch.cross(p1, normals, dim=1), dim=1)
    dot = (p1 * normals).sum(dim=1)
    angle = torch.atan2(cross_norm, dot)

    # NaN angles (point at the origin) compare False and are left untouched.
    flip = (angle >= -np.pi / 2) & (angle <= np.pi / 2)

    planes_consist = planes.new_zeros((planes.size(0), 6))
    planes_consist[:, 0:3] = torch.where(flip.unsqueeze(1), -normals, normals)
    planes_consist[:, 3:6] = centres
    return planes_consist


In [None]:
# Farthest point sampling for augmented points
# X+ = {x_i + n_i * eps}
# X- = {x_i - n_i * eps}
# For every file in ./binvoxSurfaceModels: read the occupied voxel
# coordinates, estimate normals, offset the points by +/- eps along the
# normals and save the result under ./augmentedSurfacePoints/<name>.pt.
# NOTE(review): `device` is not defined in this part of the notebook.
path = "./binvoxSurfaceModels"
from pytorch3d.ops import sample_farthest_points
files = os.listdir(path)
arr = []  # not used in this cell
max = 0  # NOTE(review): rebinds the builtin `max` at notebook scope; not used in this cell
eps = 1.0  # offset distance along the normal
for i in range (len(files)):
  item = f'./{path}/{files[i]}'
  with open(item, 'rb') as f:
    pcItem = read_as_coord_array(f)
    # pcItem.data is 3 x N; permute + unsqueeze turns it into [1, N, 3].
    data = torch.from_numpy(pcItem.data.astype(float)).to(device)
    subData = data.permute(1,0).unsqueeze(0).to(device)
    length = torch.full((1,), subData.shape[1]).to(device)
    # NOTE(review): `val` (2048 farthest-point samples) is not used below;
    # normals and offsets are computed from all of subData - confirm intent.
    val = sample_farthest_points(subData, length, 2048)
    # Normals from the 32 nearest neighbours; keep the first three columns
    # of the second returned element.
    normal = normaldefinition_3D_real(subData, 32)
    normals = normal[1][:,:3]
    S = subData.shape[1]
    # The torch.tensor((S, 3)) assignments are placeholders that are
    # overwritten on the following line.
    x_plus = torch.tensor((S, 3))
    x_plus = subData + (eps * normals).to(device)
    x_minus = torch.tensor((S, 3))
    x_minus = subData - (eps * normals).to(device)
    # Stack X+ and X- along dim 0.
    augmentedPoints = torch.cat((x_plus, x_minus), dim=0).to(device)
    dataName = files[i].split(".")[0]
    # NOTE(review): the output directory is not created here; torch.save
    # presumably fails if ./augmentedSurfacePoints does not exist.
    torch.save(augmentedPoints, f'./augmentedSurfacePoints/{dataName}.pt')
    print(i, " : ", dataName)

In [None]:
# Archive the generated augmented points and copy the archive to /save/path.
!zip -r "./augmentedSurfacePoints.zip" "./augmentedSurfacePoints"
! cp augmentedSurfacePoints.zip /save/path/augmentedSurfacePoints.zip

In [None]:
# Farthest point sampling for the surface of a model
# For every file in ./binvoxSurfaceModels: read the occupied voxel
# coordinates, draw 2048 farthest-point samples and save the value returned
# by sample_farthest_points under ./sampledSurfacePoints/<name>.pt.
# NOTE(review): `device` is not defined in this part of the notebook.
path = "./binvoxSurfaceModels"
from pytorch3d.ops import sample_farthest_points
files = os.listdir(path)
arr = []  # not used in this cell
max = 0  # NOTE(review): rebinds the builtin `max` at notebook scope; not used in this cell
eps = 1.0  # not used in this cell
for i in range (len(files)):
  item = f'./{path}/{files[i]}'
  with open(item, 'rb') as f:
    pcItem = read_as_coord_array(f)
    # pcItem.data is 3 x N; permute + unsqueeze turns it into [1, N, 3].
    data = torch.from_numpy(pcItem.data.astype(float)).to(device)
    subData = data.permute(1,0).unsqueeze(0).to(device)
    length = torch.full((1,), subData.shape[1]).to(device)
    # NOTE(review): the whole return value is saved as is; elsewhere in this
    # notebook the points are read from it as element [0] - confirm that the
    # loader expects the full object.
    surfacePoints = sample_farthest_points(subData, length, 2048)
    dataName = files[i].split(".")[0]
    # NOTE(review): the output directory is not created here.
    torch.save(surfacePoints, f'./sampledSurfacePoints/{dataName}.pt')
    print(i, " : ", dataName)

In [None]:
# Archive the sampled surface points and copy the archive to /save/path.
!zip -r "./sampledSurfacePoints.zip" "./sampledSurfacePoints"
! cp sampledSurfacePoints.zip /save/path/sampledSurfacePoints.zip

In [None]:
# Farthest point sampling for a solid model
# Includes half points of the model and half points in the empty space outside
# For every file in ./binvoxModels: take 1024 farthest-point samples of the
# occupied voxels plus up to 1024 random empty voxels and save them together.
# NOTE(review): `device` is not defined in this part of the notebook.
path = "./binvoxModels"
from pytorch3d.ops import sample_farthest_points
files = os.listdir(path)
arr = []  # not used in this cell
max = 0  # NOTE(review): rebinds the builtin `max` at notebook scope; not used in this cell
eps = 1.0  # not used in this cell
# Write to the directory that the archive step right after this cell zips
# (the previous './occupanyPoints' spelling never matched it), and make sure
# it exists before saving.
output_dir = "./occupancyKernelPoints"
os.makedirs(output_dir, exist_ok=True)
for i in range (len(files)):
  item = os.path.join(path, files[i])
  with open(item, 'rb') as f:
    pcItem = read_as_coord_array(f)
    # pcItem.data is 3 x N; permute + unsqueeze turns it into [1, N, 3].
    data = torch.from_numpy(pcItem.data.astype(float)).to(device)
    subData = data.permute(1,0).unsqueeze(0).to(device)
    length = torch.full((1,), subData.shape[1]).to(device)
    val = sample_farthest_points(subData, length, 1024)
    # The grid sampler expects an [N, 3] tensor and builds its grid on the
    # CPU, so drop the batch dimension and hand it a host copy.
    empty_points = sample_from_grid_space(subData.squeeze(0).cpu()).to(device)
    occupancyPoints = torch.cat((val[0].squeeze(0), empty_points)).to(device)
    dataName = files[i].split(".")[0]
    torch.save(occupancyPoints, f'{output_dir}/{dataName}.pt')
    print(i, " : ", dataName)

In [None]:
# Archive ./occupancyKernelPoints and copy the archive to /save/path.
# NOTE(review): the unzip cell near the top of the notebook reads
# occupancyPoints_1.zip, not this archive name - confirm which is current.
!zip -r "./occupancyKernelPoints.zip" "./occupancyKernelPoints"
! cp occupancyKernelPoints.zip /save/path/occupancyKernelPoints.zip

In [None]:
import plotly.graph_objects as go

def plotNormalsAndPoints(tensor):
    """Show points and their normal vectors in an interactive 3-D plot.

    Args:
        tensor: array-like of shape ``[points, 6]``; columns 0-2 are the
            normal vector and columns 3-5 the point position. Torch tensors
            on any device (with or without gradients) are accepted.

    The normals are drawn as given; they are not re-normalised.
    """
    # Plotly cannot read device-resident or grad-tracking tensors, so bring
    # the data to the host first (plotPointClouds does the same).
    if torch.is_tensor(tensor):
        tensor = tensor.detach().cpu()
    values = np.asarray(tensor)

    normals = values[:, :3]

    # Separate the XYZ components of the points
    points_x = values[:, 3]
    points_y = values[:, 4]
    points_z = values[:, 5]

    fig = go.Figure()

    # Add the points to the plot
    fig.add_trace(go.Scatter3d(x=points_x, y=points_y, z=points_z, mode='markers', marker=dict(size=2, color='blue')))

    # Add the normal vectors as arrows, one cone trace per point
    for i in range(len(points_x)):
        fig.add_trace(go.Cone(x=[points_x[i]], y=[points_y[i]], z=[points_z[i]], u=[normals[i, 0]], v=[normals[i, 1]], w=[normals[i, 2]], sizemode="absolute", sizeref=1.0, anchor="tail"))

    # Set the layout
    fig.update_layout(scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'))

    # Show the plot
    fig.show()


In [None]:
import plotly.graph_objects as go

def plotPointClouds(input):
  """Show a point cloud as an interactive 3-D scatter plot.

  Args:
    input: 2-D tensor with one point per row; after transposing, rows 0, 1
      and 2 are used as the x, y and z coordinates.
  """
  # Transpose to coordinate-major layout and move to the host for plotting.
  coords = input.permute(1,0).cpu()

  scatter = go.Scatter3d(x=coords[0], y=coords[1], z=coords[2], mode="markers", marker=dict(size=1))
  fig = go.Figure(data=[scatter])

  axis_titles = dict(xaxis_title='X', yaxis_title="Y", zaxis_title="Z")
  fig.update_layout(scene=axis_titles)

  fig.show()

In [None]:
import torch

def map_points_to_grid(points, grid_size, num_features, points_features, extent=128):
    """
    Map points to a grid and average the features that fall into each cell.

    Args:
    - points: Tensor of shape (B, N, 3) with (x, y, z) coordinates. Values are
              expected in [0, extent]; anything outside is clamped to the
              nearest border cell.
    - grid_size: Tuple (grid_x, grid_y, grid_z) with the number of cells per axis.
    - num_features: Number of features per point.
    - points_features: Tensor of shape (B, N, num_features) with the features
              of each point.
    - extent: Edge length of the cube covered by the grid (default 128).

    Returns:
    - grid_tensor: Tensor of shape (B, grid_x, grid_y, grid_z, num_features),
              on the device and with the dtype of points_features, holding the
              mean feature of each cell. Cells without points are zero.
    """
    B = points.size(0)
    gx, gy, gz = (int(g) for g in grid_size)
    num_cells = gx * gy * gz

    # Cell index of every point along each axis. `.long()` truncates like
    # int() does; the clamp keeps border and out-of-range points inside.
    coords = points if torch.is_floating_point(points) else points.to(torch.get_default_dtype())
    dims = torch.tensor([gx, gy, gz], device=coords.device)
    cell_size = extent / dims.to(coords.dtype)
    cell_idx = (coords / cell_size).long()
    cell_idx = torch.minimum(cell_idx.clamp(min=0), dims - 1)

    # Flatten (batch, x, y, z) into one index so that a single scatter-add
    # replaces the per-point Python loop.
    flat_idx = (cell_idx[..., 0] * gy + cell_idx[..., 1]) * gz + cell_idx[..., 2]
    batch_offset = torch.arange(B, device=flat_idx.device).unsqueeze(1) * num_cells
    flat_idx = (flat_idx + batch_offset).reshape(-1).to(points_features.device)

    feats = points_features.reshape(-1, num_features)
    sums = feats.new_zeros((B * num_cells, num_features)).index_add(0, flat_idx, feats)
    counts = feats.new_zeros(B * num_cells).index_add(
        0, flat_idx, feats.new_ones(flat_idx.numel()))

    # Avoid division by zero for empty cells
    counts = counts.clamp(min=1)

    # Perform average pooling
    grid_tensor = sums / counts.unsqueeze(-1)
    return grid_tensor.view(B, gx, gy, gz, num_features)

# Define grid size
# Number of cells per axis passed to map_points_to_grid; Model.forward reads
# this module-level name.
grid_size = (8, 8, 8)

# Map points to the grid and perform average pooling
# (example usage, left disabled)
# grid_tensor = map_points_to_grid(inputTensor, grid_size, out.size(2), out)
# grid_tensor = grid_tensor.permute(0,4,1,2,3)
# print("Grid tensor shape:", grid_tensor.shape)

## Encoder

Based on Convolutional Occupancy Networks: https://github.com/autonomousvision/convolutional_occupancy_networks/tree/master

In [None]:
class ResNetBlock(nn.Module):
  """Fully connected residual block with pre-activation.

  Computes ``shortcut(x) + fc_1(relu(fc_0(relu(x))))``. The shortcut is the
  identity when ``in_size == out_size`` and a bias-free linear projection
  otherwise. ``fc_1.weight`` starts at zero, so a fresh block adds only the
  ``fc_1`` bias to the shortcut path.

  Args:
    in_size: input feature size.
    out_size: output feature size.
    h_size: hidden feature size; defaults to ``min(in_size, out_size)``.
  """

  def __init__(self, in_size, out_size, h_size=None):
    super().__init__()

    self.in_size = in_size
    self.h_size = min(in_size, out_size) if h_size is None else h_size
    self.out_size = out_size

    self.fc_0 = nn.Linear(in_size, self.h_size)
    self.fc_1 = nn.Linear(self.h_size, out_size)
    self.actfn = nn.ReLU()

    # A projection is only needed when the two sizes differ.
    needs_projection = in_size != out_size
    self.shortcut = nn.Linear(in_size, out_size, bias=False) if needs_projection else None

    nn.init.zeros_(self.fc_1.weight)

  def forward(self, x):
    hidden = self.fc_0(self.actfn(x))
    residual = self.fc_1(self.actfn(hidden))
    skip = x if self.shortcut is None else self.shortcut(x)
    return skip + residual

In [None]:
def normalize_3d_coordinate(p, padding=0.1):
  """Rescale coordinates into the unit interval.

  ``p`` is divided by ``1 + padding + 10e-4`` and shifted by 0.5. Results at
  or above 1 are set to ``1 - 10e-4`` and negative results to 0. The input
  tensor is not modified.
  """
  shifted = p / (1 + padding + 10e-4) + 0.5
  # Pull outliers back into [0, 1).
  if shifted.max() >= 1:
    shifted.masked_fill_(shifted >= 1, 1 - 10e-4)
  if shifted.min() < 0:
    shifted.masked_fill_(shifted < 0, 0.0)
  return shifted

In [None]:
def coordinate2index(x, reso):
  """Convert coordinates to flat cell indices of a ``reso``^3 grid.

  ``x`` has shape (B, N, 3). Each coordinate is multiplied by ``reso`` and
  truncated, then combined as ``ix + reso * (iy + reso * iz)`` and wrapped
  modulo ``reso**3``. The result has shape (B, 1, N).
  """
  cells = (x * reso).long()
  ix, iy, iz = cells[:, :, 0], cells[:, :, 1], cells[:, :, 2]
  flat = ix + reso * (iy + reso * iz)
  return flat.unsqueeze(1) % reso**3

In [None]:
class PointNetEncoder(nn.Module):
  """Per-point feature encoder with grid-local max pooling.

  A linear layer lifts each point to ``2 * hidden_dim`` features, followed by
  ``num_blocks`` residual blocks. Before every block except the first, each
  point's features are concatenated with the pooled features of the grid cell
  it is assigned to. A final linear layer maps to ``out_dim``.

  Args:
    in_dim: size of each input point.
    out_dim: size of each output feature vector.
    hidden_dim: hidden width of the residual blocks.
    num_blocks: number of residual blocks.
    grid_resolution: cells per axis of the pooling grid.
  """

  def __init__(self, in_dim, out_dim, hidden_dim=128, num_blocks=5, grid_resolution=None):
    super().__init__()

    self.fc_1 = nn.Linear(in_dim, 2*hidden_dim)
    residual_blocks = [ResNetBlock(2*hidden_dim, hidden_dim) for _ in range(num_blocks)]
    self.blocks = nn.ModuleList(residual_blocks)
    self.fc_2 = nn.Linear(hidden_dim, out_dim)
    self.actfn = nn.ReLU()
    self.hidden_dim = hidden_dim
    self.scatter = scatter_max
    self.reso_grid = grid_resolution

  def pool_local(self, xy, index, c):
    """Pool features per grid cell and hand the result back to every point.

    ``c`` is (B, N, F) and ``index['grid']`` holds the cell id of each point.
    ``xy`` is accepted but not used.
    """
    fea_dim = c.size(2)
    cell_ids = index['grid']
    pooled = self.scatter(c.permute(0,2,1), cell_ids, dim_size=self.reso_grid**3)
    if self.scatter == scatter_max:
      # scatter_max returns a pair; element 0 holds the pooled values.
      pooled = pooled[0]
    # Look up, for every point, the pooled feature of its own cell.
    per_point = pooled.gather(dim=2, index=cell_ids.expand(-1, fea_dim, -1))
    c_out = 0 + per_point
    return c_out.permute(0, 2, 1)

  def forward(self, x):
    """Encode points ``x`` of shape (B, N, in_dim) to (B, N, out_dim)."""
    coord = {'grid': normalize_3d_coordinate(x.clone())}
    # NOTE(review): the cell ids are computed from the raw coordinates `x`,
    # not from the normalised coordinates in `coord` - confirm this is
    # intended for the coordinate range in use.
    index = {'grid': coordinate2index(x, self.reso_grid)}

    net = self.fc_1(x)
    net = self.blocks[0](net)
    for block in self.blocks[1:]:
      local_context = self.pool_local(coord, index, net)
      net = block(torch.cat([net, local_context], dim=2))

    return self.fc_2(net)

In [None]:
# Point encoder instance: 3 input values per point, 32 output features,
# pooling grid resolution 8.
# NOTE(review): `device` is not defined in this part of the notebook.
PointEncoder = PointNetEncoder(3, 32, grid_resolution=8).to(device)

## UNet

In [None]:
def conv3d(in_channels, out_channels, kernel_size, bias, padding=1):
  """Build an ``nn.Conv3d`` layer with the given padding and bias setting."""
  layer = nn.Conv3d(
      in_channels,
      out_channels,
      kernel_size,
      padding=padding,
      bias=bias,
  )
  return layer

In [None]:
def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=1):
  """Build the named layers of one convolution stage.

  Args:
    in_channels: channels entering the convolution.
    out_channels: channels leaving the convolution.
    kernel_size: convolution kernel size.
    order: string of layer codes applied left to right:
      'c' Conv3d, 'r' ReLU, 'l' LeakyReLU(0.1), 'e' ELU,
      'g' GroupNorm, 'b' BatchNorm3d.
    num_groups: groups for GroupNorm; reduced to 1 when the normalised
      channel count is smaller than it.
    padding: convolution padding.

  Returns:
    List of ``(name, module)`` tuples in the order given by ``order``.

  Raises:
    ValueError: if ``order`` contains an unsupported code, or uses 'g'/'b'
      without a 'c' to position them against.
  """
  modules = []

  for i, char in enumerate(order):
      if char == 'r':
          modules.append(('ReLU', nn.ReLU(inplace=True)))
      elif char == 'l':
          modules.append(('LeakyReLU', nn.LeakyReLU(negative_slope=0.1, inplace=True)))
      elif char == 'e':
          modules.append(('ELU', nn.ELU(inplace=True)))
      elif char == 'c':
          # add learnable bias only in the absence of batchnorm/groupnorm
          bias = not ('g' in order or 'b' in order)
          modules.append(('conv', nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)))
      elif char == 'g':
          # Normalisation placed before the conv sees the input channels,
          # after it the output channels.
          is_before_conv = i < order.index('c')
          num_channels = in_channels if is_before_conv else out_channels

          # use only one group if the given number of groups is greater than the number of channels
          if num_channels < num_groups:
              num_groups = 1

          assert num_channels % num_groups == 0, f'Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}'
          modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)))
      elif char == 'b':
          is_before_conv = i < order.index('c')
          num_channels = in_channels if is_before_conv else out_channels
          modules.append(('batchnorm', nn.BatchNorm3d(num_channels)))
      else:
          # A typo in `order` must not silently drop a layer.
          raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c']")
  return modules

In [None]:
class SingleConv(nn.Sequential):
  """One convolution stage: the layers built by ``create_conv``, in order.

  ``order`` selects and orders the layers (e.g. 'crg' is conv, ReLU,
  GroupNorm); see ``create_conv`` for the remaining arguments.
  """

  def __init__(self, in_c, out_c, kernel_size=3, order='crg', num_groups=8, padding=1):
    super().__init__()

    named_layers = create_conv(in_c, out_c, kernel_size, order, num_groups, padding=padding)
    for layer_name, layer in named_layers:
      self.add_module(layer_name, layer)

In [None]:
class DoubleConv(nn.Sequential):
  """Two consecutive ``SingleConv`` stages.

  On the encoder path the first stage outputs ``out_c // 2`` channels, but
  never fewer than ``in_c``; on the decoder path it goes straight to
  ``out_c``. The second stage always ends at ``out_c``.
  """

  def __init__(self, in_c, out_c, encoder, kernel_size=3, order='crg', num_groups=8):
    super().__init__()

    # Channel count between the two stages.
    mid_c = max(out_c // 2, in_c) if encoder else out_c

    self.add_module('SingleConv1',
                    SingleConv(in_c, mid_c, kernel_size, order, num_groups))
    self.add_module('SingleConv2',
                    SingleConv(mid_c, out_c, kernel_size, order, num_groups))

In [None]:
class Encoder(nn.Module):
  """One encoder level: optional pooling followed by a convolution module.

  Args:
    in_c: input channels.
    out_c: output channels.
    conv_kernel_size: kernel size of the convolutions in ``basic_module``.
    apply_pooling: if False the input is passed to the convolutions as is.
    pool_kernel_size: kernel size of the pooling layer.
    pool_type: 'max' for ``nn.MaxPool3d`` or 'avg' for ``nn.AvgPool3d``.
    basic_module: class building the convolutions; it is called with
      ``encoder=True``.
    conv_layer_order: layer order string forwarded to ``basic_module``.
    num_groups: GroupNorm groups forwarded to ``basic_module``.

  Raises:
    ValueError: if pooling is enabled and ``pool_type`` is not 'max' or 'avg'.
  """

  def __init__(self, in_c, out_c, conv_kernel_size=3, apply_pooling=True, pool_kernel_size=(2,2,2), pool_type="max", basic_module=DoubleConv, conv_layer_order='crg', num_groups=8):
    super(Encoder, self).__init__()
    if not apply_pooling:
      self.pooling = None
    elif pool_type == "max":
      self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size)
    elif pool_type == "avg":
      # Previously `pool_type` was ignored and max pooling was always used.
      self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size)
    else:
      raise ValueError(f"Unsupported pool_type '{pool_type}'; expected 'max' or 'avg'")
    self.basic_module = basic_module(in_c, out_c, encoder=True, kernel_size=conv_kernel_size, order=conv_layer_order, num_groups=num_groups)

  def forward(self, x):
    if self.pooling is not None:
      x = self.pooling(x)
    x = self.basic_module(x)
    return x

In [None]:
class Upsampling(nn.Module):
  """Upsample a feature map to the spatial size of an encoder feature map.

  With ``transposed_conv`` a learned ``nn.ConvTranspose3d`` is used (stride
  ``scale_factor``, padding 1); otherwise ``F.interpolate`` with ``mode``.
  """

  def __init__(self, transposed_conv, in_c=None, out_c=None, kernel_size=3, scale_factor=(2,2,2), mode='nearest'):
    super().__init__()

    if not transposed_conv:
      self.upsample = partial(self._interpolate, mode=mode)
    else:
      self.upsample = nn.ConvTranspose3d(in_c, out_c, kernel_size=kernel_size, stride=scale_factor, padding=1)

  def forward(self, encoder_features, x):
    # Only the spatial dimensions of the encoder features are needed.
    target_size = encoder_features.size()[2:]
    return self.upsample(x, target_size)

  @staticmethod
  def _interpolate(x, size, mode):
    """Resize ``x`` to ``size`` with the given interpolation mode."""
    return F.interpolate(x, size=size, mode=mode)

In [None]:
class Decoder(nn.Module):
  """One decoder level: upsample, join with encoder features, convolve.

  The incoming tensor is upsampled to the spatial size of the encoder
  features, concatenated with them along the channel dimension and passed
  through ``basic_module`` (built with ``encoder=False``).
  """

  def __init__(self, in_c, out_c, kernel_size=3, scale_factor=(2,2,2), basic_module=DoubleConv, conv_layer_order='crg', num_groups=8, mode='nearest', transposed_conv=False):
    super().__init__()

    self.upsampling = Upsampling(
        transposed_conv=transposed_conv,
        in_c=in_c,
        out_c=out_c,
        kernel_size=kernel_size,
        scale_factor=scale_factor,
        mode=mode,
    )

    # Skip connections are always joined by concatenation.
    self.joining = partial(self._joining, concat=True)

    self.basic_module = basic_module(
        in_c,
        out_c,
        encoder=False,
        kernel_size=kernel_size,
        order=conv_layer_order,
        num_groups=num_groups,
    )

  def forward(self, encoder_features, x):
    upsampled = self.upsampling(encoder_features=encoder_features, x=x)
    joined = self.joining(encoder_features, upsampled)
    return self.basic_module(joined)

  @staticmethod
  def _joining(encoder_features, x, concat):
    """Combine skip and upsampled features by concatenation or summation."""
    if not concat:
      return encoder_features + x
    return torch.cat((encoder_features, x), dim=1)

In [None]:
def number_of_features_per_level(init_channel_number, num_levels):
    """Return the channel count per level, doubling at every level."""
    widths = []
    for level in range(num_levels):
        widths.append(init_channel_number * 2 ** level)
    return widths

In [None]:
class Abstract3DUNet(nn.Module):
  """3-D U-Net built from ``Encoder`` and ``Decoder`` levels.

  Args:
    in_c: input channels.
    out_c: channels produced by the final 1x1x1 convolution.
    basic_module: convolution module class used at every level.
    f_maps: channels per level, or an int that is doubled ``num_levels`` times.
    layer_order: layer order string forwarded to the convolution modules.
    num_groups: GroupNorm groups forwarded to the convolution modules.
    num_levels: number of levels when ``f_maps`` is an int.
    testing: stored on the instance.

  ``final_sigmoid``, ``is_segmentation`` and extra keyword arguments are
  accepted but not used; ``final_activation`` is always None.

  ``forward`` returns ``(output, encoders_features)`` where
  ``encoders_features`` lists the input of every encoder, deepest first.
  """

  def __init__(self, in_c, out_c, final_sigmoid, basic_module, f_maps=128, layer_order='gcr', num_groups=8, num_levels=3, is_segmentation=False, testing=False, **kwargs):
    super().__init__()

    self.testing = testing

    if isinstance(f_maps, int):
      f_maps = number_of_features_per_level(f_maps, num_levels=num_levels)

    # Contracting path: the first level keeps the resolution, every later
    # level pools before convolving.
    encoder_levels = []
    for level, width in enumerate(f_maps):
      if level == 0:
        stage = Encoder(in_c, width, apply_pooling=False, basic_module=basic_module,
                        conv_layer_order=layer_order, num_groups=num_groups)
      else:
        stage = Encoder(f_maps[level - 1], width, basic_module=basic_module,
                        conv_layer_order=layer_order, num_groups=num_groups)
      encoder_levels.append(stage)
    self.encoders = nn.ModuleList(encoder_levels)

    # Expanding path: one decoder per pair of neighbouring levels, deepest
    # first.
    widths_desc = list(reversed(f_maps))
    decoder_levels = []
    for deeper, shallower in zip(widths_desc[:-1], widths_desc[1:]):
      # DoubleConv receives the concatenated skip + upsampled features.
      if basic_module == DoubleConv:
        decoder_in = deeper + shallower
      else:
        decoder_in = deeper
      decoder_levels.append(Decoder(decoder_in, shallower, basic_module=basic_module,
                                    conv_layer_order=layer_order, num_groups=num_groups))
    self.decoders = nn.ModuleList(decoder_levels)

    self.final_conv = nn.Conv3d(f_maps[0], out_c, 1)
    self.final_activation=None

  def forward(self, x):
    # Remember what each encoder received; after reversing, the list is
    # aligned with the decoders.
    skips = []
    for encoder in self.encoders:
      skips.append(x)
      x = encoder(x)
    skips.reverse()

    encoders_features = list(skips)

    # zip stops at the last decoder, so the network input (the final list
    # entry) is never used as a skip connection.
    for decoder, skip in zip(self.decoders, encoders_features):
      x = decoder(skip, x)

    x = self.final_conv(x)
    return x, encoders_features

In [None]:
class UNet3D(Abstract3DUNet):
  """3D U-Net whose stages are built from ``DoubleConv`` blocks.

  Every constructor argument is handed unchanged to ``Abstract3DUNet``; the only
  thing fixed here is ``basic_module=DoubleConv``.
  """

  def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=32, layer_order='gcr',
              num_groups=8, num_levels=3, is_segmentation=False, **kwargs):
    # Translate this class's argument names to the ones the base class expects.
    base_config = dict(
        in_c=in_channels,
        out_c=out_channels,
        final_sigmoid=final_sigmoid,
        basic_module=DoubleConv,
        f_maps=f_maps,
        layer_order=layer_order,
        num_groups=num_groups,
        num_levels=num_levels,
        is_segmentation=is_segmentation,
    )
    super().__init__(**base_config, **kwargs)

In [None]:
# 3D U-Net taking 32 input channels and producing 32 output channels, placed on `device`.
Unet = UNet3D(in_channels=32, out_channels=32).to(device)

## Model

In [None]:
class Model(nn.Module):
  """Chain a point encoder and a 3D U-Net into one trainable module.

  ``forward`` encodes the points, hands them to ``map_points_to_grid`` (defined
  elsewhere in this file), runs the U-Net on the result and upsamples the U-Net
  output by a factor of 16 with trilinear interpolation.
  """

  def __init__(self, PointEncoder, Unet):
    """Store the two sub-networks.

    Args:
      PointEncoder: callable applied to the raw input of ``forward``.
      Unet: callable applied to the gridded features; must return a pair whose
        first element is the feature volume.
    """
    super(Model, self).__init__()

    self.pointEncoder = PointEncoder
    self.Unet = Unet

  def forward(self, x):
    """Return the upsampled feature volume for the input ``x``.

    The result stays attached to the autograd graph so that a loss computed
    from it can update the encoder and the U-Net. Wrap the call in
    ``torch.no_grad()`` when gradients are not wanted.
    """
    encoded_points = self.pointEncoder(x)
    # size(2) is passed as the feature width -- assumes encoded_points is
    # (batch, points, features); TODO confirm against the encoder.
    grid_tensor = map_points_to_grid(x, grid_size, encoded_points.size(2), encoded_points)
    # Move the last axis to position 1, the channel slot the 3D convolutions read.
    grid_tensor = grid_tensor.permute(0, 4, 1, 2, 3)
    feature_grid, _ = self.Unet(grid_tensor)
    # Deliberately NOT detached: detaching here cut the graph at the model
    # output, so none of the parameters given to the optimizer could ever
    # receive a gradient.
    expanded_grid = nn.functional.interpolate(feature_grid, scale_factor=16, mode='trilinear')
    return expanded_grid

In [None]:
# Bundle the point encoder and the U-Net into a single module and move it to `device`.
model = Model(PointEncoder=PointEncoder, Unet=Unet).to(device)

In [None]:
# Adam over every parameter of `model`, with a step size of 1e-4.
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

## Training helper functions

In [None]:
def createGramMatrix(points_i, points_j):
  """
    Pair every row of points_i with every row of points_j, separately per batch element.

    Input:  points_i [batch_size, num_points_i, num_values] and
            points_j [batch_size, num_points_j, num_values]
            Ex: [16, 2048, 3] and [16, 5000, 3]
    Output: A tensor of size [batch_size, num_points_i, num_points_j, 2, num_values]
            where out[b, i, j, 0] is points_i[b, i] and out[b, i, j, 1] is points_j[b, j]
  """
  count_i = points_i.shape[1]
  count_j = points_j.shape[1]

  # Repeat each set along the other set's axis by broadcasting.
  tiled_i = points_i.unsqueeze(2).expand(-1, -1, count_j, -1)
  tiled_j = points_j.unsqueeze(1).expand(-1, count_i, -1, -1)

  # Lay the two vectors side by side, then split that axis into (pair slot, values).
  side_by_side = torch.cat((tiled_i, tiled_j), dim=3)
  return side_by_side.reshape(points_i.shape[0], count_i, count_j, 2, -1)

In [None]:
def trilinear_interpolation(query_points, grid, index_offset=2):
    """Blend the eight grid cells around each query point with trilinear weights.

    Args:
        query_points: tensor indexed as [batch, point, 3]; the last axis holds
            the coordinates used for the 3rd, 4th and 5th axes of ``grid``.
        grid: tensor laid out as [batch, channels, X, Y, Z].
        index_offset: number of cells subtracted from ``floor(query_points)`` to
            get the base corner. Defaults to 2, the value that used to be
            hard-coded, so existing callers see no change.

    Returns:
        Tensor of shape [batch, point, channels].

    NOTE(review): the weights are measured from the *shifted* base corner, so
    with a non-zero ``index_offset`` they lie in
    [index_offset, index_offset + 1) instead of [0, 1) and the result is a
    linear extrapolation from that cell. This reproduces the previous
    behaviour; confirm it is intended. ``index_offset=0`` gives textbook
    interpolation. Indices are not bounds-checked, and negative indices wrap
    around as usual for tensor indexing.
    """
    # Channels last, so one indexing operation yields a whole feature vector.
    grid = grid.permute(0, 2, 3, 4, 1)

    base_corner = query_points.floor().long() - index_offset
    x0, y0, z0 = base_corner[:, :, 0], base_corner[:, :, 1], base_corner[:, :, 2]
    x1, y1, z1 = x0 + 1, y0 + 1, z0 + 1

    # Column of batch indices, created on the query device so the lookups below
    # do not mix CPU and accelerator index tensors.
    batch_enum = torch.arange(query_points.shape[0], device=query_points.device).unsqueeze(1)

    # Values at the eight corners of the cell; each is [batch, point, channels].
    c000 = grid[batch_enum, x0, y0, z0]
    c001 = grid[batch_enum, x0, y0, z1]
    c010 = grid[batch_enum, x0, y1, z0]
    c011 = grid[batch_enum, x0, y1, z1]
    c100 = grid[batch_enum, x1, y0, z0]
    c101 = grid[batch_enum, x1, y0, z1]
    c110 = grid[batch_enum, x1, y1, z0]
    c111 = grid[batch_enum, x1, y1, z1]

    # Per-axis weights; the trailing singleton axis broadcasts over channels.
    u = (query_points[:, :, 0] - x0.float()).unsqueeze(-1)
    v = (query_points[:, :, 1] - y0.float()).unsqueeze(-1)
    w = (query_points[:, :, 2] - z0.float()).unsqueeze(-1)

    # Weighted sum of the corners (the eight weights always add up to 1).
    interpolated_value = (1 - u) * (1 - v) * (1 - w) * c000 + \
                         (1 - u) * (1 - v) * w * c001 + \
                         (1 - u) * v * (1 - w) * c010 + \
                         (1 - u) * v * w * c011 + \
                         u * (1 - v) * (1 - w) * c100 + \
                         u * (1 - v) * w * c101 + \
                         u * v * (1 - w) * c110 + \
                         u * v * w * c111
    return interpolated_value

In [None]:
def calculateTheta(x_tilde, x_tilde_prime):
  """Angle in radians, within [0, pi], between paired vectors along the last axis.

  Uses the atan2 form
      theta = 2 * atan2(| |b| a - |a| b |, | |b| a + |a| b |)
  instead of acos of the normalised dot product. For unit vectors the two
  norms equal 2 sin(theta / 2) and 2 cos(theta / 2), so atan2 alone yields only
  HALF the angle; the factor 2 restores the full angle.

  Args:
    x_tilde: tensor [..., D].
    x_tilde_prime: tensor with the same shape as ``x_tilde``.

  Returns:
    Tensor [...] of angles; a pair containing a zero vector gives 0.
  """
  # keepdim keeps a trailing axis of size 1 so the norms broadcast against the
  # vectors whatever the number of leading axes.
  norm = torch.linalg.norm(x_tilde, dim=-1, keepdim=True)
  norm_prime = torch.linalg.norm(x_tilde_prime, dim=-1, keepdim=True)
  numerator = torch.linalg.norm(norm_prime * x_tilde - norm * x_tilde_prime, dim=-1)
  denominator = torch.linalg.norm(norm_prime * x_tilde + norm * x_tilde_prime, dim=-1)
  theta = 2 * torch.atan2(numerator, denominator)
  return theta

In [None]:
def calculateNeuralSpline(x):
    """Evaluate the kernel for every vector pair stored in ``x``.

    Args:
        x: tensor [..., 2, D] as produced by ``createGramMatrix``: ``x[..., 0, :]``
            is the first vector of a pair and ``x[..., 1, :]`` the second.

    Returns:
        ``(|a| * |b| / pi) * (sin(theta) + 2 * (pi - theta) * cos(theta))`` for
        each pair, where ``theta`` comes from ``calculateTheta``, with every
        size-1 axis squeezed away.

    NOTE(review): ``squeeze()`` also drops a batch axis of size 1, so a
    single-element batch yields a 2-D result.
    """
    # Select along the PAIR axis (second to last). Indexing the last axis
    # instead would pick one coordinate from both vectors of the pair.
    x_tilde = x[..., 0, :]
    x_tilde_prime = x[..., 1, :]

    theta = calculateTheta(x_tilde, x_tilde_prime)
    firstTerm = (torch.linalg.norm(x_tilde, dim=-1) * torch.linalg.norm(x_tilde_prime, dim=-1) / np.pi)
    secondTerm = (torch.sin(theta) + 2 * (np.pi - theta) * torch.cos(theta))
    kernelVal = firstTerm * secondTerm
    return kernelVal.squeeze()  # Remove singleton dimensions

In [None]:
def calculateKernel(points_i, points_j, grid):
  """Kernel values between two point sets, using coordinates plus grid features.

  Each point is extended with the features that ``trilinear_interpolation``
  reads from ``grid`` at its location; all (i, j) pairs of extended points are
  then built by ``createGramMatrix`` and scored by ``calculateNeuralSpline``.
  """
  def with_features(points):
    # Append the sampled features to the coordinates along the last axis.
    return torch.cat((points, trilinear_interpolation(points, grid)), dim=2)

  pair_tensor = createGramMatrix(with_features(points_i), with_features(points_j))
  return calculateNeuralSpline(pair_tensor)

In [None]:
def f_x(alpha, new_points, original_points, grid):
  """Evaluate the kernel expansion f(x) = sum_j alpha_j * K(x, x_j) at ``new_points``.

  Args:
    alpha: coefficients, one per original point -- expected as
      [batch, num_original, 1], which is what the training loop produces.
    new_points: locations at which to evaluate f.
    original_points: the points the coefficients belong to.
    grid: feature volume forwarded to ``calculateKernel``.

  Returns:
    ``K(new_points, original_points) @ alpha``, i.e. [batch, num_new, 1] for the
    shapes above.
  """
  kernel = calculateKernel(new_points, original_points, grid)
  # The sum over j is a matrix-vector product. An element-wise `alpha * kernel`
  # would broadcast alpha along the wrong axis and return a matrix per batch
  # element instead of one value per new point.
  return torch.matmul(kernel, alpha)

## Training loop

In [None]:
epochs = 100
step = 0
train_loader_len = len(train_loader)

# Targets for the kernel solve: one value per input point, +1 for the first
# 1024 rows and -1 for the remaining 1024, for EVERY element of the batch.
yVector = torch.ones((b_size, 2048, 1)).to(device)
yVector[:, 1024:] = -1
# Targets for the occupancy loss: 1 for the first 1024 rows, 0 for the rest,
# again for every element of the batch.
yVolume = torch.ones((b_size, 2048, 1)).to(device)
yVolume[:, 1024:] = 0

lambdaVal = 0.0001  # ridge term added to the diagonal of the Gram matrix
l1_lambda = 0.001   # weight of the surface penalty
bceLoss = nn.BCEWithLogitsLoss()

for epoch in range(epochs):
  print(f'***************** epoch: {epoch} *****************')
  for idx, data in enumerate(train_loader):
    # The last batch of every epoch is skipped -- presumably because it can
    # hold fewer than b_size items, which the fixed-size targets above cannot
    # accommodate; confirm.
    if idx >= train_loader_len - 1:
      continue
    step += 1

    # `data` holds file names; load the three tensors stored under each name.
    # NOTE: torch.load unpickles its input -- only point it at trusted files.
    tempPointsList = []
    tempOccupancyPointsList = []
    tempSurfacePointsList = []
    for name in data:
      tempPointsList.append(torch.load(f'./augmentedSurfacePoints/{name}'))
      tempOccupancyPointsList.append(torch.load(f'./occupancyPoints_1/{name}'))
      # Only the first entry of the stored surface sample is used.
      tempSurfacePointsList.append(torch.load(f'./sampledSurfacePoints/{name}')[0])
    inputPoints = torch.stack(tempPointsList, dim=0).to(device)
    occupancyTensor = torch.stack(tempOccupancyPointsList, dim=0).to(device)
    surfaceTensor = torch.stack(tempSurfacePointsList, dim=0).to(device)

    expanded_grid = model(inputPoints)

    # Gram matrix of the input points against themselves. It is computed
    # without gradient tracking, so alpha is treated as a constant during
    # backpropagation (the same effect the earlier explicit detach had).
    with torch.no_grad():
      Kns = calculateKernel(inputPoints, inputPoints, expanded_grid)

    # Solve (K + lambda * I) alpha = y directly instead of forming the inverse.
    identityMatrix = torch.eye(Kns.shape[1], device=device, dtype=Kns.dtype)
    alpha = torch.linalg.solve(Kns + lambdaVal * identityMatrix, yVector)

    occupancyPred = f_x(alpha, occupancyTensor, inputPoints, expanded_grid)
    surfacePred = f_x(alpha, surfaceTensor, inputPoints, expanded_grid)
    occupancyLoss = bceLoss(occupancyPred, yVolume)
    # Sum |f| over the surface points of each shape, then average what is left
    # so the total loss is a scalar (required by .item() and .backward()).
    surfaceLoss = l1_lambda * torch.sum(torch.abs(surfacePred), dim=1).mean()
    loss = occupancyLoss + surfaceLoss
    print("Loss", loss.detach().item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()