# Learning a Probabilistic Latent Space of Object Shapes via 3D Generative-Adversarial Modeling

In [1]:
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc

%matplotlib inline
%config InlineBackend.figure_format='retina'
sns.set(style="whitegrid", palette="muted", font_scale=1.2)
# rcParams["figure.figsize"] = 16,10

import torch
from torch import nn, optim
from torch.optim import lr_scheduler
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchsummary import summary
import albumentations as A

from sklearn.model_selection import KFold, StratifiedKFold

import numpy as np
import pandas as pd
import wandb
import scipy.io as io

from tqdm import tqdm
from typing import Optional, List, Tuple, Dict
import os
from glob import glob
import random
import copy

%load_ext watermark
%watermark -v -p torch,wandb,matplotlib,seaborn,numpy,pandas,albumentations,scipy,sklearn

Python implementation: CPython
Python version       : 3.10.8
IPython version      : 8.9.0

torch         : 2.0.0.dev20230208
wandb         : 0.13.10
matplotlib    : 3.6.3
seaborn       : 0.12.2
numpy         : 1.24.1
pandas        : 1.5.3
albumentations: 1.3.0
scipy         : 1.10.0
sklearn       : 1.2.1



## Model Architectures

### Image Encoder

In [2]:
class ImageEncoder(nn.Module):
    def __init__(
        self, 
        in_channels: int = 3, 
        channels: List[int] = [64, 128, 256, 512, 400],
        kernel_sizes: List[int] = [11, 5, 5, 5, 8],
        strides: List[int] = [4, 2, 2, 2, 1]
    ) -> None:
        super(ImageEncoder, self).__init__()
        self.in_channels = in_channels
        layers = []
        
        for ix in range(len(channels)):
            layers.append(
                nn.Sequential(
                    nn.Conv2d(
                        in_channels=in_channels, 
                        out_channels=channels[ix],
                        kernel_size=kernel_sizes[ix], 
                        stride=strides[ix], 
                        padding=1
                    ), 
                    nn.BatchNorm2d(channels[ix]),
                    nn.ReLU()
                )
            )
            in_channels = channels[ix]
        
        self.net = nn.Sequential(*layers)
        
    def sample_normal(self, std: torch.Tensor) -> torch.Tensor:
        sampler = torch.distributions.Normal(loc=0, scale=1)
        return sampler.sample(std.shape)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        latent = self.net(x)
        mu, std = latent[:, :200,], latent[:, 200:]
        z = self.sample_normal(std)
        latent = mu + z * std
        return latent, mu, std

In [3]:
enc = ImageEncoder()
enc

ImageEncoder(
  (net): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (2): Sequential(
      (0): Conv2d(128, 256, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (3): Sequential(
      (0): Conv2d(256, 512, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (4): Sequential(
      (0): Conv2d(512, 400, kernel_size=(8, 8), stride=(1, 1), padding=(1, 1))
    

In [4]:
summary(enc, input_size=(3, 256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 62, 62]          23,296
       BatchNorm2d-2           [-1, 64, 62, 62]             128
              ReLU-3           [-1, 64, 62, 62]               0
            Conv2d-4          [-1, 128, 30, 30]         204,928
       BatchNorm2d-5          [-1, 128, 30, 30]             256
              ReLU-6          [-1, 128, 30, 30]               0
            Conv2d-7          [-1, 256, 14, 14]         819,456
       BatchNorm2d-8          [-1, 256, 14, 14]             512
              ReLU-9          [-1, 256, 14, 14]               0
           Conv2d-10            [-1, 512, 6, 6]       3,277,312
      BatchNorm2d-11            [-1, 512, 6, 6]           1,024
             ReLU-12            [-1, 512, 6, 6]               0
           Conv2d-13            [-1, 400, 1, 1]      13,107,600
      BatchNorm2d-14            [-1, 40

### Generator

In [5]:
class Generator(nn.Module):
    def __init__(
        self, 
        in_channels: int = 200,
        channels: List[int] = [512, 256, 128, 64, 1],
        kernel_sizes: List[int] = [4, 4, 4, 4, 4],
        strides: List[int] = [1, 2, 2, 2, 2], 
        paddings: List[int] = [0, 1, 1, 1, 1]
    ) -> None:
        super(Generator, self).__init__()
        layers = []
        
        for ix in range(len(channels)):
            layers.append(
                nn.Sequential(
                    nn.ConvTranspose3d(
                        in_channels=in_channels, 
                        out_channels=channels[ix],
                        stride=strides[ix],
                        kernel_size=kernel_sizes[ix], 
                        padding=paddings[ix]
                    ), 
                    nn.BatchNorm3d(channels[ix]), 
                    nn.ReLU()
                )
            )
            in_channels = channels[ix]
        layers.append(nn.Sigmoid())
        self.net = nn.Sequential(*layers)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

In [6]:
gen = Generator()
gen

Generator(
  (net): Sequential(
    (0): Sequential(
      (0): ConvTranspose3d(200, 512, kernel_size=(4, 4, 4), stride=(1, 1, 1))
      (1): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): ConvTranspose3d(512, 256, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
      (1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (2): Sequential(
      (0): ConvTranspose3d(256, 128, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
      (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (3): Sequential(
      (0): ConvTranspose3d(128, 64, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
      (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (4): Sequential(
      (0): ConvTranspose3d(64, 1

In [7]:
summary(gen, input_size=(200, 1, 1, 1))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
   ConvTranspose3d-1         [-1, 512, 4, 4, 4]       6,554,112
       BatchNorm3d-2         [-1, 512, 4, 4, 4]           1,024
              ReLU-3         [-1, 512, 4, 4, 4]               0
   ConvTranspose3d-4         [-1, 256, 8, 8, 8]       8,388,864
       BatchNorm3d-5         [-1, 256, 8, 8, 8]             512
              ReLU-6         [-1, 256, 8, 8, 8]               0
   ConvTranspose3d-7      [-1, 128, 16, 16, 16]       2,097,280
       BatchNorm3d-8      [-1, 128, 16, 16, 16]             256
              ReLU-9      [-1, 128, 16, 16, 16]               0
  ConvTranspose3d-10       [-1, 64, 32, 32, 32]         524,352
      BatchNorm3d-11       [-1, 64, 32, 32, 32]             128
             ReLU-12       [-1, 64, 32, 32, 32]               0
  ConvTranspose3d-13        [-1, 1, 64, 64, 64]           4,097
      BatchNorm3d-14        [-1, 1, 64,

### Discriminator

In [8]:
class Discriminator(nn.Module):
    def __init__(
        self,
        in_channels: int = 1, 
        channels: List[int] = [64, 128, 256, 512, 1],
        kernel_sizes: List[int] = [4, 4, 4, 4, 4],
        strides: List[int] = [4, 2, 2, 2, 1],
    ) -> None:
        super(Discriminator, self).__init__()
        layers = []
        
        for ix in range(len(channels)):
            layers.append(
                nn.Sequential(
                    nn.Conv3d(
                        in_channels=in_channels, 
                        out_channels=channels[ix], 
                        kernel_size=kernel_sizes[ix],
                        stride=strides[ix], 
                        padding=1
                    ), 
                    nn.BatchNorm3d(channels[ix]),
                    nn.LeakyReLU(0.2)
                )
            )
            in_channels = channels[ix]
        
        self.net = nn.Sequential(*layers)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

In [9]:
disc = Discriminator()
disc

Discriminator(
  (net): Sequential(
    (0): Sequential(
      (0): Conv3d(1, 64, kernel_size=(4, 4, 4), stride=(4, 4, 4), padding=(1, 1, 1))
      (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (1): Sequential(
      (0): Conv3d(64, 128, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
      (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (2): Sequential(
      (0): Conv3d(128, 256, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
      (1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv3d(256, 512, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
      (1): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(neg

In [10]:
summary(disc, input_size=(1, 64, 64, 64))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1       [-1, 64, 16, 16, 16]           4,160
       BatchNorm3d-2       [-1, 64, 16, 16, 16]             128
         LeakyReLU-3       [-1, 64, 16, 16, 16]               0
            Conv3d-4         [-1, 128, 8, 8, 8]         524,416
       BatchNorm3d-5         [-1, 128, 8, 8, 8]             256
         LeakyReLU-6         [-1, 128, 8, 8, 8]               0
            Conv3d-7         [-1, 256, 4, 4, 4]       2,097,408
       BatchNorm3d-8         [-1, 256, 4, 4, 4]             512
         LeakyReLU-9         [-1, 256, 4, 4, 4]               0
           Conv3d-10         [-1, 512, 2, 2, 2]       8,389,120
      BatchNorm3d-11         [-1, 512, 2, 2, 2]           1,024
        LeakyReLU-12         [-1, 512, 2, 2, 2]               0
           Conv3d-13           [-1, 1, 1, 1, 1]          32,769
      BatchNorm3d-14           [-1, 1, 

## Utils

In [11]:
files = []

with open('../../data/IKEA_imgs_shapes/list/bed.txt') as f:
    files = f.readlines()

In [12]:
output_paths = []
for file in files:
    if file.strip() + '.mat' in os.listdir('../../data/IKEA_imgs_shapes/model/'):
        output_paths.append(
            os.path.join('../../data/IKEA_imgs_shapes/model/', f'{file.strip()}.mat')
        )
len(output_paths)

60

In [13]:
def load_data(
    pixel_path: Optional[str] = None, 
    voxel_path: Optional[str] = None,
):
    pixel, voxel = None, None
    if pixel_path:
        pixel = plt.imread(pixel_path)
    if voxel_path:
        voxel = io.loadmat(voxel_path)
    
    return pixel, voxel['voxel']

In [14]:
pixel_paths = glob('../../data/IKEA_imgs_shapes/bed/*.png')
pixel_paths = sorted(pixel_paths)
len(pixel_paths)

60

In [15]:
pixel_path, voxel_path = list(zip(pixel_paths, output_paths))[0]
pixel_path, voxel_path

('../../data/IKEA_imgs_shapes/bed/0001.png',
 '../../data/IKEA_imgs_shapes/model/IKEA_bed_MALM_malm_bed_2_obj0_object.mat')

In [16]:
pixels, voxels = load_data(pixel_path, voxel_path)

In [17]:
pixels.shape, voxels.shape

((244, 395, 3), (64, 64, 64))

### Config

In [18]:
class CFG:
    # thanks to awsaf for some amazing code style!!
    seed = 101
    exp_name = '2D_to_3D'
    comment = 'Trying out the 2D to 3D GAN from 3dgan.csail'
    model_name = 'VAE-3D-GAN'
    train_bs = 4
    valid_bs = 2*train_bs
    image_size = [256, 256]
    epochs = 5
    scheduler = 'CosineAnnealingLR'
    min_lr = 1e-6
    T_max = int(30000/train_bs*epochs)+50
    T_0 = 25
    warmup_epochs = 0
    n_accumulate  = max(1, 32//train_bs)
    n_fold = 5
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    num_classes = None
    weight_decay = 1e-5

cfg = CFG()

### Dataset

In [19]:
class VAE3dGANDataset(Dataset):
    def __init__(self, in_paths: List[str], out_paths: List[str], stage: str = "train") -> None:
        super(VAE3dGANDataset, self).__init__()
        self.in_paths = in_paths
        self.out_paths = out_paths
        self.stage = stage
        
        if self.stage == "train":
            '''Ref: https://www.kaggle.com/code/alexeyolkhovikov/segformer-training'''
            self.transforms = A.Compose([
                A.augmentations.crops.RandomResizedCrop(height=cfg.image_size[0], width=cfg.image_size[1]),
                A.augmentations.Rotate(limit=90, p=0.2),
                A.augmentations.HorizontalFlip(p=0.2),
                A.augmentations.VerticalFlip(p=0.2),
                A.augmentations.transforms.ColorJitter(p=0.5),
                A.OneOf([
                    A.OpticalDistortion(p=0.2),
                    A.GridDistortion(p=0.2),
                    A.PiecewiseAffine(p=0.2)
                ], p=0.5),
                A.OneOf([
                    A.HueSaturationValue(10, 15, 10),
#                     A.CLAHE(clip_limit=4),
                    A.RandomBrightnessContrast()
                ], p=0.5),
                A.Normalize()
            ])
        else:
            self.transforms = A.Compose([
                A.Resize(cfg.image_size[0], cfg.image_size[1]),
                A.Normalize()
            ])
        
    
    def __len__(self) -> int:
        return len(self.in_paths)
    
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        in_path = self.in_paths[idx]
        out_path = self.out_paths[idx]
        
        pixels, voxels = load_data(in_path, out_path)
        
        # the transform is applied to only the 2d image
        pixels = self.transforms(image=pixels)
        
        return {
            "x": torch.tensor(pixels['image'], dtype=torch.float32).permute(2, 0, 1).to(cfg.device),
            "y": torch.tensor(voxels, dtype=torch.float32).unsqueeze(0).to(cfg.device)
        }

In [20]:
dataset = VAE3dGANDataset(
    in_paths=pixel_paths, 
    out_paths=output_paths, 
    stage="train"
)

batch = dataset[0]
{k:v.shape for k, v in batch.items()}

{'x': torch.Size([3, 256, 256]), 'y': torch.Size([1, 64, 64, 64])}

In [21]:
dataset = DataLoader(dataset, batch_size=cfg.train_bs, shuffle=True)

In [22]:
batch = next(iter(dataset))
{k:v.shape for k, v in batch.items()}

{'x': torch.Size([4, 3, 256, 256]), 'y': torch.Size([4, 1, 64, 64, 64])}

### Optimizers and Schedulers

In [23]:
def get_scheduler(optimizer):
    '''Ref: https://www.kaggle.com/code/awsaf49/uwmgi-2-5d-train-pytorch?scriptVersionId=97382768&cellId=55'''
    if cfg.scheduler == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer, 
            T_max=cfg.T_max, 
            eta_min=cfg.min_lr
        )
    elif cfg.scheduler == 'CosineAnnealingWarmRestarts':
        scheduler = lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer=optimizer,
            T_0=cfg.T_0, 
            eta_min=cfg.min_lr
        )
    elif cfg.scheduler == 'ReduceLROnPlateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(
            optimizer=optimizer, 
            mode='min',
            factor=0.1,
            patience=0.7, 
            threshold=0.0001, 
            min_lr=cfg.min_lr
        )
    elif cfg.scheduler == 'ExponentialLR':
        scheduler = lr_scheduler.ExponentialLR(
            optimizer=optimizer, 
            gamma=0.85
        )
    elif cfg.scheduler == None:
        return None
    else:
        raise NotImplementedError(f"The scheduler `{cfg.scheduler}` has not been implememted")
    return scheduler

### Reproducibility

In [24]:
def set_seed(seed: int = 42):
    '''Ref: https://www.kaggle.com/code/awsaf49/uwmgi-2-5d-train-pytorch?scriptVersionId=97382768&cellId=15'''
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    
    # When running on CudNN  backend, two further options
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)
    print('> SEEDING DONE')
    
set_seed(cfg.seed)

> SEEDING DONE


### Create Folds

In [25]:
def create_folds(data_pth: str = '../../data/IKEA_imgs_shapes.csv'):
    df = pd.read_csv(data_pth)
    df = df.sample(frac=1).reset_index()
    df["kfold"] = -1
    kf = KFold(n_splits=cfg.n_fold)
    
    for fold, (train, valid) in enumerate(kf.split(df)):
        df.loc[valid, "kfold"] = fold
    
    df.to_csv('../../data/IKEA_imgs_shapes_folds.csv')

In [26]:
batch.keys()

dict_keys(['x', 'y'])

In [27]:
x, y = batch['x'], batch['y']
x.shape, y.shape

(torch.Size([4, 3, 256, 256]), torch.Size([4, 1, 64, 64, 64]))

In [28]:
encoder = ImageEncoder()
generator = Generator()
discriminator = Discriminator()

a, mu, std = encoder.forward(x)
b = generator.forward(a.unsqueeze(-1))
c = discriminator.forward(b).squeeze()
a.shape, b.shape, c.shape

(torch.Size([4, 200, 1, 1]), torch.Size([4, 1, 64, 64, 64]), torch.Size([4]))

In [29]:
def kl_divergence(mu: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    return torch.mean(-0.5 * torch.sum(1 + std - mu ** 2 - std.exp(), dim = 1), dim = 0)

def gfv_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return nn.MSELoss()(x, y)

kl_div = kl_divergence(mu=mu, std=std)
kl_div

tensor([[90.3263]], grad_fn=<MeanBackward1>)

In [30]:
encoder = ImageEncoder()
generator = Generator()
discriminator = Discriminator()
enc_optim = optim.Adam(encoder.parameters(), lr=3e-4)
gen_optim = optim.Adam(generator.parameters(), lr=3e-4)
dis_optim = optim.Adam(discriminator.parameters(), lr=3e-4)

In [31]:
y.shape, c.shape

(torch.Size([4, 1, 64, 64, 64]), torch.Size([4]))

In [33]:
for _ in tqdm(range(10)):
    a, mu, std = encoder.forward(x)
    b = generator.forward(a.unsqueeze(-1))
    c = discriminator.forward(b).squeeze()
    
    kl_loss = kl_divergence(mu, std)
    gan_loss = nn.BCEWithLogitsLoss()(y, torch.ones_like(y)) + nn.BCEWithLogitsLoss()(c, torch.zeros_like(c))
    

100%|███████████████████████████████████████████| 10/10 [00:03<00:00,  2.58it/s]
