# Learning a Probabilistic Latent Space of Object Shapes via 3D Generative-Adversarial Modeling

In [1]:
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc

%matplotlib inline
%config InlineBackend.figure_format='retina'
sns.set(style="whitegrid", palette="muted", font_scale=1.2)
rcParams["figure.figsize"] = 16,10

import torch
from torch import nn, optim
import torch.nn.functional as F

import numpy as np
import wandb
from typing import Optional, List, Tuple, Dict
from torchsummary import summary

%load_ext watermark
%watermark -v -p torch,wandb,matplotlib,seaborn,numpy,pylab

Python implementation: CPython
Python version       : 3.10.8
IPython version      : 8.9.0

torch     : 2.0.0.dev20230208
wandb     : 0.13.10
matplotlib: 3.6.3
seaborn   : 0.12.2
numpy     : 1.24.1
pylab     : unknown



## Network Structure

### Image Encoder

In [2]:
class VAE3DGAN(nn.Module):
    def __init__(
        self, 
        in_channels: int = 3, 
        channels: List[int] = [64, 128, 256, 512, 400],
        kernel_sizes: List[int] = [11, 5, 5, 5, 8],
        strides: List[int] = [4, 2, 2, 2, 1]
    ) -> None:
        super(VAE3DGAN, self).__init__()
        self.in_channels = in_channels
        layers = []
        
        for ix in range(len(channels)):
            layers.append(
                nn.Sequential(
                    nn.Conv2d(
                        in_channels=in_channels, 
                        out_channels=channels[ix],
                        kernel_size=kernel_sizes[ix], 
                        stride=strides[ix], 
                        padding=1
                    ), 
                    nn.BatchNorm2d(channels[ix]),
                    nn.ReLU()
                )
            )
            in_channels = channels[ix]
        
        self.net = nn.Sequential(*layers)
        
    def sample_normal(self, std: torch.Tensor) -> torch.Tensor:
        sampler = torch.distributions.Normal(loc=0, scale=1)
        return sampler.sample(std.shape)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        latent = self.net(x)
        mu, std = latent[:, :200,], latent[:, 200:]
        z = self.sample_normal(std)
        latent = mu + z * std
        return latent

In [3]:
vae = VAE3DGAN()
vae

VAE3DGAN(
  (net): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (2): Sequential(
      (0): Conv2d(128, 256, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (3): Sequential(
      (0): Conv2d(256, 512, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (4): Sequential(
      (0): Conv2d(512, 400, kernel_size=(8, 8), stride=(1, 1), padding=(1, 1))
      (1

In [4]:
summary(vae, input_size=(3, 256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 62, 62]          23,296
       BatchNorm2d-2           [-1, 64, 62, 62]             128
              ReLU-3           [-1, 64, 62, 62]               0
            Conv2d-4          [-1, 128, 30, 30]         204,928
       BatchNorm2d-5          [-1, 128, 30, 30]             256
              ReLU-6          [-1, 128, 30, 30]               0
            Conv2d-7          [-1, 256, 14, 14]         819,456
       BatchNorm2d-8          [-1, 256, 14, 14]             512
              ReLU-9          [-1, 256, 14, 14]               0
           Conv2d-10            [-1, 512, 6, 6]       3,277,312
      BatchNorm2d-11            [-1, 512, 6, 6]           1,024
             ReLU-12            [-1, 512, 6, 6]               0
           Conv2d-13            [-1, 400, 1, 1]      13,107,600
      BatchNorm2d-14            [-1, 40

### Generator

In [5]:
class Generator(nn.Module):
    def __init__(
        self, 
        in_channels: int = 200,
        channels: List[int] = [512, 256, 128, 64, 1],
        kernel_sizes: List[int] = [4, 4, 4, 4, 4],
        strides: List[int] = [1, 2, 2, 2, 2], 
        paddings: List[int] = [0, 1, 1, 1, 1]
    ) -> None:
        super(Generator, self).__init__()
        layers = []
        
        for ix in range(len(channels)):
            layers.append(
                nn.Sequential(
                    nn.ConvTranspose3d(
                        in_channels=in_channels, 
                        out_channels=channels[ix],
                        stride=strides[ix],
                        kernel_size=kernel_sizes[ix], 
                        padding=paddings[ix]
                    ), 
                    nn.BatchNorm3d(channels[ix]), 
                    nn.ReLU()
                )
            )
            in_channels = channels[ix]
        layers.append(nn.Sigmoid())
        self.net = nn.Sequential(*layers)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

In [6]:
gen = Generator()
gen

Generator(
  (net): Sequential(
    (0): Sequential(
      (0): ConvTranspose3d(200, 512, kernel_size=(4, 4, 4), stride=(1, 1, 1))
      (1): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): ConvTranspose3d(512, 256, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
      (1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (2): Sequential(
      (0): ConvTranspose3d(256, 128, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
      (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (3): Sequential(
      (0): ConvTranspose3d(128, 64, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
      (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (4): Sequential(
      (0): ConvTranspose3d(64, 1

In [7]:
summary(gen, input_size=(200, 1, 1, 1))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
   ConvTranspose3d-1         [-1, 512, 4, 4, 4]       6,554,112
       BatchNorm3d-2         [-1, 512, 4, 4, 4]           1,024
              ReLU-3         [-1, 512, 4, 4, 4]               0
   ConvTranspose3d-4         [-1, 256, 8, 8, 8]       8,388,864
       BatchNorm3d-5         [-1, 256, 8, 8, 8]             512
              ReLU-6         [-1, 256, 8, 8, 8]               0
   ConvTranspose3d-7      [-1, 128, 16, 16, 16]       2,097,280
       BatchNorm3d-8      [-1, 128, 16, 16, 16]             256
              ReLU-9      [-1, 128, 16, 16, 16]               0
  ConvTranspose3d-10       [-1, 64, 32, 32, 32]         524,352
      BatchNorm3d-11       [-1, 64, 32, 32, 32]             128
             ReLU-12       [-1, 64, 32, 32, 32]               0
  ConvTranspose3d-13        [-1, 1, 64, 64, 64]           4,097
      BatchNorm3d-14        [-1, 1, 64,

### Discriminator

In [8]:
class Discriminator(nn.Module):
    def __init__(
        self,
        in_channels: int = 1, 
        channels: List[int] = [64, 128, 256, 512, 1],
        kernel_sizes: List[int] = [4, 4, 4, 4, 4],
        strides: List[int] = [4, 2, 2, 2, 1],
    ) -> None:
        super(Discriminator, self).__init__()
        layers = []
        
        for ix in range(len(channels)):
            layers.append(
                nn.Sequential(
                    nn.Conv3d(
                        in_channels=in_channels, 
                        out_channels=channels[ix], 
                        kernel_size=kernel_sizes[ix],
                        stride=strides[ix], 
                        padding=1
                    ), 
                    nn.BatchNorm3d(channels[ix]),
                    nn.LeakyReLU(0.2)
                )
            )
            in_channels = channels[ix]
        
        self.net = nn.Sequential(*layers)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

In [9]:
disc = Discriminator()
disc

Discriminator(
  (net): Sequential(
    (0): Sequential(
      (0): Conv3d(1, 64, kernel_size=(4, 4, 4), stride=(4, 4, 4), padding=(1, 1, 1))
      (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (1): Sequential(
      (0): Conv3d(64, 128, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
      (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (2): Sequential(
      (0): Conv3d(128, 256, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
      (1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv3d(256, 512, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
      (1): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(neg

In [10]:
summary(disc, input_size=(1, 64, 64, 64))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1       [-1, 64, 16, 16, 16]           4,160
       BatchNorm3d-2       [-1, 64, 16, 16, 16]             128
         LeakyReLU-3       [-1, 64, 16, 16, 16]               0
            Conv3d-4         [-1, 128, 8, 8, 8]         524,416
       BatchNorm3d-5         [-1, 128, 8, 8, 8]             256
         LeakyReLU-6         [-1, 128, 8, 8, 8]               0
            Conv3d-7         [-1, 256, 4, 4, 4]       2,097,408
       BatchNorm3d-8         [-1, 256, 4, 4, 4]             512
         LeakyReLU-9         [-1, 256, 4, 4, 4]               0
           Conv3d-10         [-1, 512, 2, 2, 2]       8,389,120
      BatchNorm3d-11         [-1, 512, 2, 2, 2]           1,024
        LeakyReLU-12         [-1, 512, 2, 2, 2]               0
           Conv3d-13           [-1, 1, 1, 1, 1]          32,769
      BatchNorm3d-14           [-1, 1, 