In [5]:
## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim


In [6]:
class Encoder(nn.Module):

    def __init__(self,
                 num_input_channels : int,
                 base_channel_size : int,
                 latent_dim : int,
                 act_fn : object = nn.GELU):
        """
        Inputs:
            - num_input_channels : Number of input channels of the image. For CIFAR, this parameter is 3
            - base_channel_size : Number of channels we use in the first convolutional layers. Deeper layers might use a duplicate of it.
            - latent_dim : Dimensionality of latent representation z
            - act_fn : Activation function used throughout the encoder network
        """
        super().__init__()
        c_hid = base_channel_size
        self.net = nn.Sequential(
            nn.Conv3d(num_input_channels, c_hid, kernel_size=3, padding=1, stride=2), # 32x32x32 => 16x16x16
            act_fn(),
            nn.Conv3d(c_hid, c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.Conv3d(c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2), # 16x16x16 => 8x8x8
            act_fn(),
            nn.Conv3d(2*c_hid, 2*c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.Conv3d(2*c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2), # 8x8x8 => 4x4x4
            act_fn(),
            nn.Flatten(), # Image grid to single feature vector
            nn.Linear(2*8*8*8*c_hid, latent_dim)
        )

    def forward(self, x):
        return self.net(x)


In [7]:
layer1 = Encoder(1, 3, 10)

In [8]:
data = torch.randn(16, 1, 64, 64, 64)

layer1.forward(data)

tensor([[-0.0044, -0.0051, -0.0221, -0.0036, -0.0023,  0.0142,  0.0048,  0.0170,
          0.0242,  0.0036],
        [-0.0013, -0.0046, -0.0215, -0.0026, -0.0013,  0.0148,  0.0043,  0.0193,
          0.0231,  0.0064],
        [-0.0037, -0.0044, -0.0230, -0.0023, -0.0020,  0.0141,  0.0077,  0.0162,
          0.0231,  0.0051],
        [-0.0025, -0.0043, -0.0240, -0.0008, -0.0019,  0.0120,  0.0031,  0.0175,
          0.0256,  0.0046],
        [-0.0037, -0.0042, -0.0238, -0.0024, -0.0021,  0.0123,  0.0046,  0.0195,
          0.0231,  0.0047],
        [-0.0033, -0.0020, -0.0232, -0.0033, -0.0020,  0.0133,  0.0046,  0.0166,
          0.0256,  0.0042],
        [-0.0037, -0.0035, -0.0220, -0.0020, -0.0039,  0.0154,  0.0062,  0.0192,
          0.0229,  0.0047],
        [-0.0025, -0.0046, -0.0228, -0.0030, -0.0041,  0.0133,  0.0058,  0.0172,
          0.0227,  0.0046],
        [-0.0024, -0.0022, -0.0209, -0.0039, -0.0017,  0.0133,  0.0043,  0.0171,
          0.0238,  0.0069],
        [-0.0045, -

In [9]:

class Decoder(nn.Module):
    def __init__(
        self,
        latent_size,
        dims,
        dropout=None,
        dropout_prob=0.0,
        norm_layers=(),
        latent_in=(),
        weight_norm=False,
        xyz_in_all=None,
        use_tanh=False,
        latent_dropout=False,
    ):
        super(Decoder, self).__init__()

        def make_sequence():
            return []

        dims = [latent_size + 3] + dims + [1]

        self.num_layers = len(dims)
        self.norm_layers = norm_layers
        self.latent_in = latent_in
        self.latent_dropout = latent_dropout
        if self.latent_dropout:
            self.lat_dp = nn.Dropout(0.2)

        self.xyz_in_all = xyz_in_all
        self.weight_norm = weight_norm

        for layer in range(0, self.num_layers - 1):
            if layer + 1 in latent_in:
                out_dim = dims[layer + 1] - dims[0]
            else:
                out_dim = dims[layer + 1]
                if self.xyz_in_all and layer != self.num_layers - 2:
                    out_dim -= 3

            if weight_norm and layer in self.norm_layers:
                setattr(
                    self,
                    "lin" + str(layer),
                    nn.utils.weight_norm(nn.Linear(dims[layer], out_dim)),
                )
            else:
                setattr(self, "lin" + str(layer), nn.Linear(dims[layer], out_dim))

            if (
                (not weight_norm)
                and self.norm_layers is not None
                and layer in self.norm_layers
            ):
                setattr(self, "bn" + str(layer), nn.LayerNorm(out_dim))

        self.use_tanh = use_tanh
        if use_tanh:
            self.tanh = nn.Tanh()
        self.relu = nn.ReLU()

        self.dropout_prob = dropout_prob
        self.dropout = dropout
        self.th = nn.Tanh()
    
    # input: N x (L+3)
    def forward(self, input):
        xyz = input[:, -3:]

        if input.shape[1] > 3 and self.latent_dropout:
            latent_vecs = input[:, :-3]
            latent_vecs = F.dropout(latent_vecs, p=0.2, training=self.training)
            x = torch.cat([latent_vecs, xyz], 1)
        else:
            x = input

        for layer in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(layer))
            if layer in self.latent_in:
                x = torch.cat([x, input], 1)
            elif layer != 0 and self.xyz_in_all:
                x = torch.cat([x, xyz], 1)
            x = lin(x)
            # last layer Tanh
            if layer == self.num_layers - 2 and self.use_tanh:
                x = self.tanh(x)
            if layer < self.num_layers - 2:
                if (
                    self.norm_layers is not None
                    and layer in self.norm_layers
                    and not self.weight_norm
                ):
                    bn = getattr(self, "bn" + str(layer))
                    x = bn(x)
                x = self.relu(x)
                if self.dropout is not None and layer in self.dropout:
                    x = F.dropout(x, p=self.dropout_prob, training=self.training)

        if hasattr(self, "th"):
            x = self.th(x)

        return x

In [10]:
decode_layer = Decoder(16, [32,64,64,128])

torch.randn((1,19)).shape

torch.Size([1, 19])

In [11]:
decode_layer(torch.randn((1,19)))

tensor([[0.0643]], grad_fn=<TanhBackward0>)

In [12]:
torch.randn((1,19))[:,-3:]

tensor([[-1.1312,  0.5227, -0.1918]])

In [13]:
decode_layer.parameters()

<generator object Module.parameters at 0x7fc7c94fcc10>

In [14]:
enc = Encoder(1, 3, 8)  # num input channels, base channel size, latent dim
dec = Decoder(8, [16,64,64,128])

N = 4

x = torch.randn(N, 1, 64, 64, 64)

z = enc.forward(x)

xyz = torch.randn(N, 3) + 10
z_hat = torch.cat((z,xyz), dim=1)

sdf = dec.forward(z_hat)

In [15]:

def loss(x, x_hat, delta=0.01):
    return torch.sum(x - x_hat)

In [16]:
torch.sum(torch.tensor([1,2,3]))

tensor(6)

In [17]:
from datasets import VoxelSDFDataset

data = VoxelSDFDataset()
print(len(data))

training_set, validation_set = torch.utils.data.random_split(data, [400, len(data) - 400])

training_loader = torch.utils.data.DataLoader(training_set, batch_size=8, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=8, shuffle=False)

515000


In [30]:
VoxelEncoder = Encoder(1, 3, 8)
VoxelDecoder = Decoder(8, [16,32,32,64])

encoder_optimizer = optim.Adam(VoxelEncoder.parameters(), lr=0.0001)
decoder_optimizer = optim.Adam(VoxelDecoder.parameters(), lr=0.0001)

def train_one_epoch(epoch_index, clamp_delta):
    for i, data in enumerate(training_loader):
        
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()
        voxel_data, [grid, near, surf] = data
        
        loss = 0
        
        for i, voxel in enumerate(voxel_data):
            vox_obj = voxel[None, None,:,:,:].float() # [1, 1, 64, 64, 64]
            latents = VoxelEncoder(vox_obj)
            
            def computeLoss(points, sdfs):
                l = 0
                for i in range(len(points)):
                    p = points[None, i,]
                    sdf = sdfs[i,]
                    z = torch.cat((latents, p), dim=1)
                    sdf_hat = VoxelDecoder(z.float())

                    l += abs(torch.clamp(sdf_hat - sdf, min=-clamp_delta, max=clamp_delta))
                return l

            loss += computeLoss(grid[0][i,], grid[1][i,])
            loss += computeLoss(near[0][i,], near[1][i,])
            loss += computeLoss(surf[0][i,], surf[1][i,])

        print(f'LOSS: {loss}')

        loss.backward()
        
        encoder_optimizer.step()
        decoder_optimizer.step()



In [31]:
# Initializing in a separate cell so we can easily add more epochs to the same run
epoch_number = 0

EPOCHS = 5

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    VoxelEncoder.train(True)
    VoxelDecoder.train(True)
    avg_loss = train_one_epoch(epoch_number, 0.25)

    # We don't need gradients on to do reporting
    VoxelEncoder.train(False)
    VoxelDecoder.train(False)

EPOCH 1:
LOSS: tensor([[292.5618]], grad_fn=<AddBackward0>)
LOSS: tensor([[285.5284]], grad_fn=<AddBackward0>)
LOSS: tensor([[288.6871]], grad_fn=<AddBackward0>)
LOSS: tensor([[285.1157]], grad_fn=<AddBackward0>)
LOSS: tensor([[289.7173]], grad_fn=<AddBackward0>)
LOSS: tensor([[286.1173]], grad_fn=<AddBackward0>)
LOSS: tensor([[280.6072]], grad_fn=<AddBackward0>)
LOSS: tensor([[285.1335]], grad_fn=<AddBackward0>)
LOSS: tensor([[288.5634]], grad_fn=<AddBackward0>)
LOSS: tensor([[281.9569]], grad_fn=<AddBackward0>)
LOSS: tensor([[280.0135]], grad_fn=<AddBackward0>)
LOSS: tensor([[278.1961]], grad_fn=<AddBackward0>)
LOSS: tensor([[276.2312]], grad_fn=<AddBackward0>)
LOSS: tensor([[281.4836]], grad_fn=<AddBackward0>)
LOSS: tensor([[273.8387]], grad_fn=<AddBackward0>)
LOSS: tensor([[273.7753]], grad_fn=<AddBackward0>)
LOSS: tensor([[274.9701]], grad_fn=<AddBackward0>)
LOSS: tensor([[266.0597]], grad_fn=<AddBackward0>)
LOSS: tensor([[264.1991]], grad_fn=<AddBackward0>)
LOSS: tensor([[265.726

KeyboardInterrupt: 