### Imports

In [21]:
from IPython.display import clear_output

In [24]:
%pip install -q path.py
%pip install -q pytorch3d
# https://github.com/facebookresearch/pifuhd/issues/77
%pip install -q 'torch==1.6.0+cu101' -f https://download.pytorch.org/whl/torch_stable.html
%pip install -q 'torchvision==0.7.0+cu101' -f https://download.pytorch.org/whl/torch_stable.html
%pip install -q 'pytorch3d==0.2.5'
%pip install -q Ninja
clear_output()

In [25]:
import numpy as np
import math
import random
import os
import torch
import scipy.spatial.distance
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms, utils
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pytorch3d

import plotly.graph_objects as go
import plotly.express as px

from path import Path

from pytorch3d.loss import chamfer

random.seed = 42

In [66]:
from torch.utils.tensorboard import SummaryWriter

In [26]:
!wget http://3dvision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip # /ModelNet40.zip - 40 classes
!unzip -q ModelNet10.zip

path = Path("ModelNet10")

folders = [dir for dir in sorted(os.listdir(path)) if os.path.isdir(path/dir)]

clear_output()
classes = {folder: i for i, folder in enumerate(folders)}
# classes

In [28]:
#!g1.1
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [29]:
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.1)

In [69]:
from enum import Enum
class DecoderType(Enum):
    ORIGINAL = 1
    INCREASE_POINTS = 2
    INCREASE_CHANNELS = 3
    
class DataType(Enum):
    AUG_PRE = 1 # augmentation during training
    AUG_DUR = 2
    AUG_BOTH = 3

In [27]:
# all augmentations before training
# datatype = DataType.AUG_BEROFE
trainloader_pre = torch.load('dataloaders/dataloaders_beds_pre/trainloader.pth')
validloader_pre = torch.load('dataloaders/dataloaders_beds_pre/validloader.pth')

# all augmentations during training
# datatype = DataType.AUG_DURING
trainloader_dur = torch.load('dataloaders/dataloaders_beds_dur/trainloader.pth')
validloader_dur = torch.load('dataloaders/dataloaders_beds_dur/validloader.pth') 

# static (before training) augmentations + dynamic (during training) augmentations
# datatype = DataType.AUG_BOTH
trainloader_both = torch.load('dataloaders/dataloaders_beds_both/trainloader.pth')
validloader_both = torch.load('dataloaders/dataloaders_beds_both/validloader.pth')

In [64]:
class PointNetAE(nn.Module):
    def __init__(self, num_points=1024, z_dim=100, decoder_type=DecoderType.ORIGINAL):
        super(PointNetAE, self).__init__()
        self.num_points = num_points
        self.encoder = PointEncoder(num_points, z_dim=z_dim)

        if decoder_type is DecoderType.INCREASE_POINTS:
            self.decoder = PointDecoderPoints(num_points, z_dim=z_dim)
        elif decoder_type is DecoderType.INCREASE_CHANNELS:
            self.decoder = PointDecoderChannels(num_points, z_dim=z_dim)
        else:
            self.decoder = PointDecoderOriginal(num_points, z_dim=z_dim)

        self.name = self.decoder.name

    def reparameterize(self, mu, log_var):
        std = torch.exp(log_var / 2)
        eps = torch.randn_like(std)
        return mu + std * eps

    def forward(self, x):
        x, mu, logvar = self.encoder(x)
        # x = self.reparameterize(mu, logvar)
        x = self.decoder(x)
        return x


class PointEncoder(nn.Module):
    def __init__(self, num_points, z_dim):
        super(PointEncoder, self).__init__()
        self.num_points = num_points
        self.feature_dim = z_dim
        self.convs = nn.Sequential(
            nn.Conv1d(3, 64, 1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Conv1d(128, num_points, 1),
            nn.BatchNorm1d(num_points),
        )

        self.dense = nn.Sequential(
            nn.Linear(num_points, 512),
            nn.ReLU(),
            nn.Linear(512, self.feature_dim)
        )

        self.dense.apply(init_weights)

        self.mu_fc = nn.Linear(self.feature_dim, z_dim)
        self.log_var_fc = nn.Linear(self.feature_dim, z_dim)

    def forward(self, x):
        x = self.convs(x)
        x, _ = torch.max(x, 2) # instead of maxpool
        x = x.view(-1, self.num_points)
        x = self.dense(x)
        x_relu = torch.relu(x)
        mu, log_var = self.mu_fc(x_relu), self.log_var_fc(x_relu)
        return x, mu, log_var


# ORIGINAL - all layers are linear
class PointDecoderOriginal(nn.Module):
    def __init__(self, num_points, z_dim):
        super(PointDecoderOriginal, self).__init__()
        self.num_points = num_points
        self.name = 'original'
        self.dense_layers = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.Dropout(0.2),
            nn.ReLU(),
            nn.Linear(512, num_points),
            nn.Dropout(0.3),
            nn.Linear(num_points, num_points*3),
            nn.Tanh()
        )
        self.dense_layers.apply(init_weights)

    def forward(self, x):
        batchsize = x.size()[0]
        x = self.dense_layers(x)
        x = x.view(batchsize, 3, self.num_points)
        return x


# USE CONV1D TO INCREASE NUMBER OF POINTS (z_dim -> 1024)
# class PointDecoderPoints(nn.Module):
#     def __init__(self, num_points, z_dim):
#         super(PointDecoderPoints, self).__init__()
#         self.num_points = num_points
#         self.z_dim = z_dim
#         self.name = f'model_conv1d_{z_dim}_{num_points}'
#         self.conv_layers = nn.Sequential(
#             nn.Conv1d(z_dim, 256, 1),
#             nn.BatchNorm1d(256),
#             nn.ReLU(),
#             nn.Conv1d(256, 512, 1),
#             nn.BatchNorm1d(512),
#             nn.ReLU(),
#             nn.Conv1d(512, num_points, 1),
#             nn.BatchNorm1d(num_points),
#             nn.ReLU()
#         )
#         self.linear = nn.Sequential(
#             nn.Linear(num_points, num_points*3, 1),
#             nn.Dropout(0.4),
#             nn.Tanh()
#         )
#         self.linear.apply(init_weights)

#     def forward(self, x):
#         batchsize = x.size()[0]
#         x = x.reshape(batchsize, self.z_dim, 1)
#         x = self.conv_layers(x).reshape(batchsize, self.num_points)
#         x = self.linear(x).reshape(batchsize, 3, self.num_points)
#         return x


# USE CONV1D TO INCREASE NUMBER OF DIMENSIONS (1 -> 3)
# class PointDecoderChannels(nn.Module):
#     def __init__(self, num_points, z_dim):
#         super(PointDecoderChannels, self).__init__()
#         self.num_points = num_points
#         self.name = 'model_conv1d_1_3'
#         self.dense_layers = nn.Sequential(
#             nn.Linear(z_dim, 256),
#             nn.Dropout(0.1),
#             nn.ReLU(),
#             nn.Linear(256, 512),
#             nn.Dropout(0.2),
#             nn.ReLU(),
#             nn.Linear(512, num_points),
#             nn.Dropout(0.3),
#         )
#         self.conv = nn.Sequential(
#             nn.Conv1d(1, 3, 1),
#             nn.Tanh()
#         )
#         self.dense_layers.apply(init_weights)

#     def forward(self, x):
#         batchsize = x.size()[0]
#         x = self.dense_layers(x).reshape(batchsize, 1, self.num_points)
#         x = self.conv(x)
#         return x

In [65]:
encoder = PointEncoder(1024, 100)
decoder = PointDecoderOriginal(1024, 100)
for x, _ in beds_loader:
    x = x.float().permute(0, 2, 1)
    output, _, _ = encoder(x)
    print(decoder(output).shape)
    break

torch.Size([32, 3, 1024])


In [31]:
def train_pcautoencoder(autoencoder, x, loss_func, optimizer):
    '''
    loss function must be chamfer distance
    '''
    optimizer.zero_grad()
    x = x.float().to(device).permute(0, 2, 1)
    output = autoencoder(x)
    dist1, dist2 = loss_func(x, output)

    try:
        # dist2 might be None if x_normals and y_normals (args to loss_func) are None
        loss = (torch.mean(dist1)) + (torch.mean(dist2))
    except:
        loss = (torch.mean(dist1))

    loss.backward()
    optimizer.step()

    return loss.data.item()


def validate_pcautoencoder(autoencoder, x, loss_func):
    '''
    loss function must be chamfer distance
    '''
    with torch.no_grad():
        x = x.float().to(device).permute(0, 2, 1)
        output = autoencoder(x)
        dist1, dist2 = loss_func(x, output)

        try:
            # dist2 might be None if x_normals and y_normals (args to loss_func) are None
            loss = (torch.mean(dist1)) + (torch.mean(dist2))
        except:
            loss = (torch.mean(dist1))

        return loss.data.item()

In [32]:
def train_with_chamfer_dist(autoencoder, loaders_type, loss_func, optimizer,
                            train_func, validate_func, epochs=100, print_every_e=5, valid_every=5,
                            scheduler=None, summary_writer=None, model_name='model'):
    if loaders_type is DataType.AUG_PRE:
        train_loader, valid_loader = trainloader_pre, validloader_pre
    elif loaders_type is DataType.AUG_DUR:
        train_loader, valid_loader = trainloader_dur, validloader_dur
    else:
        train_loader, valid_loader = trainloader_both, validloader_both

    autoencoder.train()
    for epoch in range(1, epochs+1):
        losses = []
        for x, _ in train_loader:
            loss = train_func(autoencoder, x, loss_func, optimizer)
            losses.append(loss)
        if summary_writer is not None:
            summary_writer.add_scalar(f'{model_name}/train/loss', np.mean(losses), epoch)
        if scheduler:
            scheduler.step()

        if epoch % print_every_e == 0 or epoch == 1:
            print(f'{epoch}:\ttrain loss: {np.mean(losses)}')
        if epoch % valid_every == 0:
            valid_losses = []
            for x, _ in valid_loader:
                valid_loss = validate_func(autoencoder, x, loss_func)
                valid_losses.append(valid_loss)
            if summary_writer is not None:
                summary_writer.add_scalar(f'{model_name}/valid/loss', np.mean(valid_losses), epoch)
            print(f'\tvalidation loss: {np.mean(valid_losses)}')

In [None]:
writer = SummaryWriter()

In [53]:
#!g1.1
# CHANGE DECODER TYPE
pc_autoencoder = PointNetAE(num_points=1024, z_dim=100, decoder_type=DecoderType.Original)
pc_autoencoder.to(device)

optimizer = optim.AdamW(pc_autoencoder.parameters(), lr=0.0009, betas=(0.8, 0.8))
# optimizer = optim.SGD(pc_autoencoder.parameters(), lr=0.01)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2500, gamma=0.5)
loss_func = chamfer.chamfer_distance

# CHANGE DATA TYPE
train_with_chamfer_dist(pc_autoencoder, loaders_type=DataType.AUG_PRE, loss_func=loss_func,
                        valid_loader=beds_loader_valid, train_func=train_pcautoencoder, validate_func=validate_pcautoencoder,
                        epochs=10000, print_every_e=100, valid_every=100, scheduler=scheduler, summary_writer=writer, 
                        model_name=pc_autoencoder.name)

1:	train loss: 339.8019332885742
100:	train loss: 124.65575838088989
	validation loss: 123.31191507975261
200:	train loss: 124.54345321655273
	validation loss: 123.2471440633138
300:	train loss: 124.54922294616699
	validation loss: 123.24136861165364
400:	train loss: 124.55694103240967
	validation loss: 123.35435994466145
500:	train loss: 124.47421646118164
	validation loss: 123.30106862386067
600:	train loss: 124.6281590461731
	validation loss: 122.83753204345703
700:	train loss: 124.43552923202515
	validation loss: 122.8958511352539
800:	train loss: 124.52057123184204
	validation loss: 123.06599680582683
900:	train loss: 124.43160152435303
	validation loss: 123.08112335205078
1000:	train loss: 124.48707723617554
	validation loss: 122.97868347167969
1100:	train loss: 124.42672157287598
	validation loss: 122.91610463460286
1200:	train loss: 124.24827289581299
	validation loss: 123.41895294189453
1300:	train loss: 124.3877854347229
	validation loss: 122.91690317789714
1400:	train loss: 

KeyboardInterrupt: 

In [45]:
#!g1.1
torch.save(pc_autoencoder.state_dict(), 'MODEL_NAME.pth')