In [1]:
import numpy as np
import torch
import mcubes
from torch import nn
from scipy import spatial
from matplotlib import pyplot as plt
import torch.nn.functional as F
import torchvision.transforms as transforms 

from PIL import Image 

In [None]:
class ReLuNet(nn.Module):
    """DeepSDF-style 8-layer MLP mapping a point to a value in [-1, 1] (tanh).

    The raw input is concatenated back onto the output of the fourth layer
    (skip connection). That layer is therefore ``hidden - ninputchannels``
    wide (509 for xyz input and the default width), so the concatenation is
    exactly ``hidden`` features wide again before fc5.

    Args:
        ninputchannels: number of coordinates per input point (3 for xyz).
        hidden: width of the hidden layers (512 by default).
    """

    def __init__(self, ninputchannels, hidden=512):
        super(ReLuNet, self).__init__()
        if hidden <= ninputchannels:
            raise ValueError("hidden must be larger than ninputchannels to leave room for the skip connection")
        # MLP with 8 layers
        self.fc1 = nn.Linear(ninputchannels, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, hidden)
        # Narrowed so that [fc4 output, raw input] has exactly `hidden` features.
        self.fc4 = nn.Linear(hidden, hidden - ninputchannels)
        self.fc5 = nn.Linear(hidden, hidden)
        self.fc6 = nn.Linear(hidden, hidden)
        self.fc7 = nn.Linear(hidden, hidden)
        self.fc8 = nn.Linear(hidden, 1)

    def forward(self, x):
        """Return one value per point; input (..., ninputchannels) -> output (..., 1)."""
        inp = x
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        # Skip connection: re-inject the raw coordinates (last dim, so unbatched input works too).
        x = torch.cat([x, inp], dim=-1)
        x = F.relu(self.fc5(x))
        x = F.relu(self.fc6(x))
        x = F.relu(self.fc7(x))
        return torch.tanh(self.fc8(x))

In [None]:
def evaluate_loss(relunet, pts_gt, sdf_gt, device, lpc, batch_size=2000, delta = 0.1, pc_batch_size=2000):
    """Clamped-L1 distance loss on a random mini-batch of supervised points.

    Draws ``batch_size`` indices uniformly (with replacement) from ``pts_gt``,
    evaluates ``relunet`` on those points and returns
    ``sum(|clamp(pred) - clamp(target)|)``, where both sides are clamped to
    ``[-delta, delta]``.

    Args:
        relunet: callable mapping an (B, D) tensor of points to (B, 1) predictions.
        pts_gt: (N, D) tensor of sample points.
        sdf_gt: tensor with N target values aligned with ``pts_gt`` (reshaped to (B, 1)).
        device: device the predictions are moved to before comparing with the targets.
        lpc: list the scalar loss value is appended to (one entry per call).
        batch_size: number of points drawn per call.
        delta: clamp threshold applied to predictions and targets.
        pc_batch_size: unused; kept so existing callers keep working.

    Returns:
        A differentiable 0-dim tensor holding the summed loss.

    Raises:
        ValueError: if ``pts_gt`` contains no points.
    """
    if pts_gt.shape[0] == 0:
        raise ValueError("pts_gt is empty; cannot sample a batch")

    # Sample on the data's device so indexing needs no host->device copy.
    indices = torch.randint(pts_gt.shape[0], (batch_size,), device=pts_gt.device)
    pred = relunet(pts_gt[indices]).to(device)
    target = sdf_gt[indices].view(-1, 1)

    # Tensor reduction instead of Python's builtin sum (which loops over rows
    # and yields a shape-(1,) tensor rather than a scalar).
    loss = (torch.clamp(pred, -delta, delta) - torch.clamp(target, -delta, delta)).abs().sum()

    lpc.append(loss.item())
    return loss

In [None]:
def _load_point_cloud(path):
    """Load an .xyz point cloud and rescale it into roughly [-1, 1]^3.

    Columns 0-2 are taken as x, y, z. If the file has at least six columns,
    columns 3-5 are returned as per-point normals (assumed to be nx, ny, nz —
    TODO confirm for armadillo_sub.xyz); otherwise normals is None.

    Returns:
        (points, normals): an (N, 3) array, and an (N, 3) array or None.
    """
    data = np.loadtxt(path, ndmin=2)
    points = data[:, 0:3]
    lower = points.min(axis=0)
    maxdim = np.max(points.max(axis=0) - lower)
    # Uniform scale + shift so the longest side spans ~[-1, 1].
    points = 1.9999 * (points - lower) / maxdim - 0.99999
    # A uniform scale and a translation leave normal directions unchanged.
    normals = data[:, 3:6] if data.shape[1] >= 6 else None
    return points, normals


def _sample_sdf(points, normals, rng, noise_scales=(0.01, 0.05)):
    """Build (query points, distance targets) from a surface point cloud.

    Queries are the cloud points themselves, one jittered copy of the cloud per
    Gaussian noise scale, and as many uniform points in [-1, 1]^3. Targets are
    the distance to the nearest cloud point (KD-tree query). When normals are
    given, the distance is negated for queries lying on the side opposite to
    the nearest point's normal (positive outside if normals point outward —
    an assumption about the input data).
    """
    tree = spatial.KDTree(points)
    jittered = [points + rng.normal(scale=s, size=points.shape) for s in noise_scales]
    uniform = rng.uniform(-1.0, 1.0, size=points.shape)
    queries = np.concatenate([points, *jittered, uniform], axis=0)
    dist, nearest = tree.query(queries)
    if normals is not None:
        side = np.einsum("ij,ij->i", queries - points[nearest], normals[nearest])
        dist = np.where(side < 0, -dist, dist)
    return queries, dist


def _extract_mesh(net, device, resolution, level, out_path, chunk=65536):
    """Evaluate ``net`` on a resolution^3 grid over [-1, 1]^3 and save its ``level`` set as OBJ."""
    axis = np.linspace(-1.0, 1.0, resolution)
    grid = np.stack(np.meshgrid(axis, axis, axis, indexing="ij"), axis=-1).reshape(-1, 3)
    values = []
    with torch.no_grad():
        for start in range(0, grid.shape[0], chunk):
            block = torch.from_numpy(grid[start:start + chunk]).float().to(device)
            values.append(net(block).squeeze(-1).cpu().numpy())
    volume = np.concatenate(values).reshape(resolution, resolution, resolution).astype(np.float64)
    vertices, triangles = mcubes.marching_cubes(volume, level)
    # marching_cubes returns vertices in voxel-index units; map them back to [-1, 1].
    vertices = vertices * (2.0 / (resolution - 1)) - 1.0
    mcubes.export_obj(vertices, triangles, out_path)
    return vertices, triangles


def main_shape(path="armadillo_sub.xyz", nepochs=10000, lr=1e-5, resolution=128,
               mesh_path="armadillo.obj", iso_level=None, seed=0):
    """Fit ReLuNet to a point cloud's distance field, extract a mesh, and plot the loss.

    Steps: load and normalize the cloud; build KD-tree distance targets on,
    near and away from the surface; train with Adam through evaluate_loss;
    run marching cubes on the trained network (OBJ written to ``mesh_path``);
    and save the loss curve to loss.pdf.

    Args:
        path: .xyz file to load.
        nepochs: number of optimization steps (one mini-batch each).
        lr: Adam learning rate.
        resolution: marching-cubes grid size per axis.
        mesh_path: output OBJ path.
        iso_level: level set to extract. None picks 0.0 when normals give a
            signed target; otherwise it picks 0.01, because an unsigned target
            never crosses 0.
        seed: seed for the numpy sampler and the torch initialization.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    rng = np.random.default_rng(seed)
    torch.manual_seed(seed)

    points, normals = _load_point_cloud(path)
    queries, sdf = _sample_sdf(points, normals, rng)

    geomnet = ReLuNet(3)
    geomnet.to(device)

    gtpoints = torch.from_numpy(queries).float().to(device)
    gtsdf = torch.from_numpy(sdf).float().to(device)

    lpc = []
    optim = torch.optim.Adam(params=geomnet.parameters(), lr=lr)

    for epoch in range(nepochs):
        optim.zero_grad()
        # evaluate_loss samples the batch, runs the network and appends to lpc itself.
        loss = evaluate_loss(geomnet, gtpoints, gtsdf, device, lpc)
        loss.backward()
        optim.step()
        if epoch % 100 == 0:
            print(f"Epoch {epoch}/{nepochs} - loss : {loss.item()}")

    # use marching cubes to extract the shape
    if iso_level is None:
        iso_level = 0.0 if normals is not None else 0.01
    _extract_mesh(geomnet, device, resolution, iso_level, mesh_path)

    # display the result
    if lpc:
        plt.figure(figsize=(6,4))
        plt.yscale('log')
        plt.plot(lpc, label = 'Point cloud loss ({:.2f})'.format(lpc[-1]))
        plt.xlabel("Epochs")
        plt.legend()
        plt.savefig("loss.pdf")
        plt.close()

In [None]:
# Run the pipeline defined above: load armadillo_sub.xyz, fit the network, and save the loss curve to loss.pdf.
main_shape()

# Background: DeepSDF

**DeepSDF** (Deep Signed Distance Function) is a neural network-based approach for representing and reconstructing 3D shapes. It leverages the concept of Signed Distance Functions (SDFs) to model the surface of 3D objects. DeepSDF was introduced by Park et al. in their 2019 paper "DeepSDF: Learning Continuous Signed Distance Functions for Shape Representation."

### Key Concepts

1. **Signed Distance Function (SDF)**:
   - An SDF is a scalar field that gives, for any point in space, its distance to the closest point on the object's surface.
   - The sign encodes the side. The value is positive if the point is outside the object, negative if it is inside, and zero on the surface, so the surface is the zero level set of the SDF.

2. **Neural Network Representation**:
   - DeepSDF uses a neural network to parameterize the SDF.
   - The network takes a 3D coordinate as input and outputs the signed distance value for that coordinate.

3. **Continuous Representation**:
   - Unlike voxel grids or point clouds, DeepSDF provides a continuous representation of 3D shapes, allowing for high-resolution and smooth surfaces.

### How DeepSDF Works

1. **Network Architecture**:
   - The DeepSDF network is typically a multi-layer perceptron (MLP) with several fully connected layers.
   - The input to the network is a 3D coordinate (x, y, z), and the output is the signed distance value for that coordinate.

2. **Training**:
   - The network is trained using a dataset of 3D shapes, where each shape is represented by a set of 3D coordinates and their corresponding signed distance values.
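   - The DeepSDF paper uses a clamped L1 loss, $\mathcal{L} = \left|\operatorname{clamp}(f_\theta(x), \delta) - \operatorname{clamp}(s, \delta)\right|$ with $\operatorname{clamp}(v, \delta) = \min(\delta, \max(-\delta, v))$. Clamping concentrates the network's capacity on the region near the surface. A plain mean squared error (MSE) between predicted and ground-truth values is a simpler alternative.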

3. **Shape Representation**:
   - Once trained, a network without a latent code encodes a single shape as a continuous SDF. Representing a variety of shapes with one network requires the latent code described below.
   - The network can be queried at any 3D coordinate to obtain the signed distance value, allowing for high-resolution surface reconstruction.

4. **Latent Code**:
   - DeepSDF can also incorporate a latent code to represent different shapes in a latent space.
   - The latent code is concatenated with the 3D coordinates and fed into the network, enabling the network to learn a continuous family of shapes.

### Example Code

Here is a simplified example of how to implement a DeepSDF-like network in PyTorch:

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DeepSDF(nn.Module):
    """Minimal DeepSDF-style decoder: (xyz, latent code) -> one predicted value.

    Four 512-wide ReLU layers followed by a linear output layer. There is no
    skip connection, and the output has no activation.

    Args:
        latent_dim: length of the latent vector concatenated to each point.
    """

    def __init__(self, latent_dim=256):
        super().__init__()
        widths = [3 + latent_dim, 512, 512, 512, 512, 1]
        # Registered as fc1..fc5 in order, so parameter names and init order are stable.
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f"fc{idx}", nn.Linear(fan_in, fan_out))
        self.latent_dim = latent_dim

    def forward(self, x, latent_code):
        """Join ``x`` and ``latent_code`` on the last dim and return the (..., 1) prediction."""
        h = torch.cat((x, latent_code), dim=-1)
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            h = F.relu(layer(h))
        return self.fc5(h)

# Example usage (seeded so the example is reproducible from a fresh kernel)
torch.manual_seed(0)
latent_dim = 256
model = DeepSDF(latent_dim=latent_dim)

# Example input: 3D coordinates and a matching latent code per point
coords = torch.randn(10, 3)  # 10 points in 3D space
latent_code = torch.randn(10, latent_dim)  # Corresponding latent code

# Forward pass
sdf_values = model(coords, latent_code)
assert sdf_values.shape == (10, 1), sdf_values.shape
print(sdf_values.shape)  # torch.Size([10, 1])

### Key Points

- **DeepSDF**: A neural network-based approach for representing and reconstructing 3D shapes using Signed Distance Functions.
- **SDF**: Represents the distance from any point in space to the closest surface of an object, with the sign indicating whether the point is inside or outside the object.
- **Continuous Representation**: Provides a high-resolution and smooth representation of 3D shapes.
- **Latent Code**: Allows the network to learn a continuous family of shapes by incorporating a latent code.