In [1]:
import torch
import torch.nn as nn

class VoxelToSDFUNet(nn.Module):
    """3D convolutional encoder-decoder that maps a voxel grid to an SDF grid.

    The network is three ``nn.Sequential`` stages applied one after another:

    * ``encoder``: two 3x3x3 convolutions, then a 2x max-pool.
    * ``latent``:  two 3x3x3 convolutions, then a second 2x max-pool.
    * ``decoder``: a 2x transposed convolution, a 3x3x3 convolution, and a
      final 2x transposed convolution that produces the output channels.

    Note: despite the "UNet" name, no skip connections are wired between the
    encoder and the decoder; ``forward`` is a plain sequential pass.

    The two pooling layers each halve the spatial size (rounding down) and the
    two transposed convolutions each double it, so the output has the same
    spatial size as the input only when every spatial dimension is a multiple
    of 4.

    Args:
        in_channels: Number of channels in the input grid. Defaults to 1.
        out_channels: Number of channels in the predicted grid. Defaults to 1.
        base_channels: Width of the first convolution; deeper layers use 2x
            and 4x this value. The default of 32 gives the 32/64/128 widths.
    """

    def __init__(self, in_channels=1, out_channels=1, base_channels=32):
        super().__init__()

        width_1x = base_channels
        width_2x = base_channels * 2
        width_4x = base_channels * 4

        # Encoder: full resolution -> 1/2 resolution
        self.encoder = nn.Sequential(
            nn.Conv3d(in_channels, width_1x, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(width_1x, width_2x, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=2, stride=2)
        )

        # Latent: 1/2 resolution -> 1/4 resolution
        self.latent = nn.Sequential(
            nn.Conv3d(width_2x, width_4x, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(width_4x, width_4x, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=2, stride=2)
        )

        # Decoder: 1/4 resolution -> full resolution.
        # No activation after the last layer, so outputs are unbounded and
        # may be negative.
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(width_4x, width_2x, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv3d(width_2x, width_1x, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose3d(width_1x, out_channels, kernel_size=2, stride=2)
        )

    def forward(self, x):
        """Run the encoder, latent and decoder stages in sequence.

        Args:
            x: Input tensor laid out as ``(N, in_channels, D, H, W)``.

        Returns:
            Tensor laid out as ``(N, out_channels, D', H', W')``; the spatial
            size equals the input's when each of D, H, W is a multiple of 4.
        """
        # Encoder
        x = self.encoder(x)
        # Latent
        x = self.latent(x)
        # Decoder (plain upsampling path; no skip connections are used)
        x = self.decoder(x)

        return x

# Instantiate the model using the constructor's defaults (no arguments passed)
model = VoxelToSDFUNet()

# Print the model architecture: the module's text representation lists each
# named stage (encoder / latent / decoder) and the layers inside it
print(model)


VoxelToSDFUNet(
  (encoder): Sequential(
    (0): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): ReLU(inplace=True)
    (2): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (latent): Sequential(
    (0): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): ReLU(inplace=True)
    (2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (decoder): Sequential(
    (0): ConvTranspose3d(128, 64, kernel_size=(2, 2, 2), stride=(2, 2, 2))
    (1): ReLU(inplace=True)
    (2): Conv3d(64, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (3): ReLU(inplace=True)
    (4): ConvTranspose3d(32, 1, kernel_size=(2, 2, 2), stride=(2, 2, 

In [2]:
# IPython shell escape: list the working directory to see which .npy files
# are available to load.
!ls

Untitled.ipynb	example_sdf.npy  example_voxels.npy  exercise1	o.npy


In [3]:
import numpy as np

# Read the example voxel grid and its SDF from disk and convert both to
# float32 tensors in a single pass (voxels first, SDF second).
voxel, sdf = (
    torch.tensor(np.load(npy_path), dtype=torch.float)
    for npy_path in ("example_voxels.npy", "example_sdf.npy")
)

In [4]:
# Shape of the SDF once two leading singleton axes are added
# (indexing with None inserts a size-1 axis, same as unsqueeze(0)).
sdf[None, None].shape

torch.Size([1, 1, 32, 32, 32])

In [5]:
# Sum over every cell of the voxel grid.
# NOTE(review): if the grid is a 0/1 occupancy mask this is the number of
# filled voxels — confirm against how example_voxels.npy was produced.
torch.sum(voxel)

tensor(5400.)

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim

# Overfit the network on a single (voxel grid, SDF) pair.
# Two leading singleton axes are added so each 3D grid becomes a 5D tensor,
# which is the layout the model's Conv3d layers are fed here.
input_voxel_grid = voxel.unsqueeze(0).unsqueeze(0)
ground_truth_sdf = sdf.unsqueeze(0).unsqueeze(0)

# Seed before the model is built so the random weight initialisation, and
# therefore the loss curve below, is the same on every Restart & Run All.
torch.manual_seed(0)

# Instantiate the model
model = VoxelToSDFUNet()

# Loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Number of training epochs
num_epochs = 100

# Training loop
model.train()
for epoch in range(num_epochs):
    # Forward pass
    output_sdf = model(input_voxel_grid)

    # MSELoss only warns and then broadcasts when the two shapes differ,
    # which would silently produce a meaningless loss. Fail loudly instead.
    if output_sdf.shape != ground_truth_sdf.shape:
        raise ValueError(
            f'Model output shape {tuple(output_sdf.shape)} does not match '
            f'target shape {tuple(ground_truth_sdf.shape)}'
        )

    loss = criterion(output_sdf, ground_truth_sdf)

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Print training statistics
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')

Epoch [1/100], Loss: 0.16143792867660522
Epoch [2/100], Loss: 0.1493491232395172
Epoch [3/100], Loss: 0.14915072917938232
Epoch [4/100], Loss: 0.1369166523218155
Epoch [5/100], Loss: 0.13153508305549622
Epoch [6/100], Loss: 0.11623180657625198
Epoch [7/100], Loss: 0.10344547778367996
Epoch [8/100], Loss: 0.09564467519521713
Epoch [9/100], Loss: 0.08972221612930298
Epoch [10/100], Loss: 0.09096783399581909
Epoch [11/100], Loss: 0.08692220598459244
Epoch [12/100], Loss: 0.08332204073667526
Epoch [13/100], Loss: 0.07950714230537415
Epoch [14/100], Loss: 0.07458368688821793
Epoch [15/100], Loss: 0.07345329225063324
Epoch [16/100], Loss: 0.07160426676273346
Epoch [17/100], Loss: 0.07149144262075424
Epoch [18/100], Loss: 0.06811050325632095
Epoch [19/100], Loss: 0.06628699600696564
Epoch [20/100], Loss: 0.06268570572137833
Epoch [21/100], Loss: 0.061969030648469925
Epoch [22/100], Loss: 0.05938291922211647
Epoch [23/100], Loss: 0.05836907774209976
Epoch [24/100], Loss: 0.057421013712882996
E

In [None]:
# Run the trained model on the training voxel grid to inspect the overfit
# prediction. This is inference only, so autograd is disabled to avoid
# building a graph that would never be used.
model.eval()  # explicit inference mode; the model has no dropout/batch-norm layers
with torch.no_grad():
    overfit_sample = model(voxel.unsqueeze(0).unsqueeze(0))

In [None]:
# Convert the prediction to a NumPy array and write it to disk.
# The conversion is guarded so this cell can be re-run: after the first run
# `overfit_sample` is already an ndarray, and calling .detach() on it again
# would raise AttributeError.
if isinstance(overfit_sample, torch.Tensor):
    # .detach() drops any autograd history; .cpu() is a no-op for CPU tensors
    # and is needed before .numpy() if the tensor lives on another device.
    overfit_sample = overfit_sample.detach().cpu().numpy()
np.save('overfit_sample.npy', overfit_sample)