In [1]:
import argparse
import os, sys
sys.path.append('/home/warrenzhao/PointNet-AE-3DMM')

import time
import copy
import math
import pickle
import statistics

import numpy as np
import pandas as pd
import open3d as o3d
import torchvision
import pytorch3d
from tqdm import tqdm
import matplotlib.pyplot as plt
import plotly.graph_objects as go

# Import pytorch dependencies
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn.modules.utils import _single, _pair, _triple
from torchsummary import summary

from pytorch3d.loss import chamfer_distance

# Import toolkits
from utils.visualization_3D_objects import *
from utils.preprocessing import *
from utils.read_object import *
from utils.model_averaging import *
from utils.model_PCA import *
from utils.morphable_model import *

from model.model import *

import emd

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Default architecture sizes for the VAE below.
latent_dim = 128  # length of the latent code z
input_dim = 1024  # number of input features per sample (after flattening)
inter_dim = 512   # width of the hidden layer in encoder and decoder

class VAE(nn.Module):
    """Fully connected variational autoencoder over flat feature vectors.

    The encoder maps ``input_dim`` features to ``2 * latent_dim`` values, split
    into a mean ``mu`` and a log-variance ``logvar``. The decoder maps a latent
    code back to ``input_dim`` features.

    After every forward pass, ``self.kl`` holds the KL divergence between
    N(mu, exp(logvar)) and N(0, I), summed over the batch and latent dims.
    """

    def __init__(self, input_dim=input_dim, inter_dim=inter_dim, latent_dim=latent_dim):
        super(VAE, self).__init__()

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, inter_dim),
            nn.BatchNorm1d(inter_dim),
            nn.ReLU(),
            nn.Linear(inter_dim, latent_dim * 2)
        )

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, inter_dim),
            nn.BatchNorm1d(inter_dim),
            nn.ReLU(),
            nn.Linear(inter_dim, input_dim)
        )

        # KL term of the most recent forward pass (0 before the first call).
        self.kl = 0

    def reparameterise(self, mu, logvar):
        """Draw z ~ N(mu, exp(logvar)) with the reparameterisation trick.

        In training mode a noisy sample is returned so gradients flow through
        ``mu`` and ``logvar``. In eval mode the mean ``mu`` is returned, so
        encoding the same input always yields the same latent code.
        """
        if not self.training:
            return mu
        epsilon = torch.randn_like(mu)
        # exp(logvar / 2) is the standard deviation.
        return mu + epsilon * torch.exp(logvar / 2)

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Args:
            x: tensor whose first dimension is the batch. The remaining
               dimensions are flattened and must total ``input_dim`` elements.

        Returns:
            (recon_x, z): the reconstruction reshaped to ``x.size()`` and the
            latent code of shape (batch, latent_dim).
        """
        org_size = x.size()
        batch = org_size[0]
        # reshape (rather than view) also accepts non-contiguous inputs.
        x = x.reshape(batch, -1)

        h = self.encoder(x)
        mu, logvar = h.chunk(2, dim=1)

        self.kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        z = self.reparameterise(mu, logvar)
        recon_x = self.decoder(z).reshape(org_size)

        return recon_x, z

In [3]:
# Paths: point PROJECT_ROOT at your own checkout of the repository.
PROJECT_ROOT = "/home/warrenzhao/PointNet-AE-3DMM"
DATA_PATH = PROJECT_ROOT + "/data/processedShapeNet/car/"
CHECKPOINT_FOLDER = PROJECT_ROOT + "/saved_model/"
MODEL_TYPE = "VAE_global_feat"  # checkpoint is loaded from CHECKPOINT_FOLDER/<MODEL_TYPE>.pth

# Batch sizes.
TRAIN_BATCH_SIZE, VAL_BATCH_SIZE = 12, 10

# Model and optimisation hyper-parameters.
LATENT_DIM = 128
INITIAL_LR = 1e-3
MOMENTUM = 0.9
BETA = 1
EPOCHS = 350
DECAY_EPOCHS = np.arange(100, 400, 100)  # -> [100, 200, 300]
DECAY = 0.1

RESUME = False

### Import preprocessed ShapeNet car global features (run shapenet_preprocessing.ipynb first to generate this data)

In [4]:
# PointNet-AE global features; presumably one row of input_dim values per object — TODO confirm.
global_feat = np.loadtxt("../PointNetAE_global_feat.csv", delimiter=",", dtype=float)
# This loader is only used to encode the features below, so every row is kept
# in file order:
# - no shuffling, so latent_vectors[i] corresponds to global_feat[i];
# - the last partial batch is not dropped.
# The model is switched to eval mode before encoding, so BatchNorm uses its
# running statistics and small final batches are fine.
data_loader = DataLoader(
    global_feat,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    drop_last=False,
)

### Load saved model

In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cuda':
    print("Run on GPU...")
else:
    print("Run on CPU...")

model_test = VAE()
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
vae_checkpoint = torch.load(
    os.path.join(CHECKPOINT_FOLDER, MODEL_TYPE + '.pth'),  # change the path to your own checkpoint file
    map_location=device,
)
# Honour the selected device instead of assuming CUDA.
model_test.to(device)
model_test.load_state_dict(vae_checkpoint['state_dict'])
# Inference mode: BatchNorm layers use their running statistics.
model_test.eval()

print(vae_checkpoint['epoch'])

Run on GPU...
332


In [6]:
def extractGlobalFeatForAllData(data_loader, model_test, device):
    """Encode every batch of ``data_loader`` with ``model_test`` and stack the latents.

    Parameters
    ----------
    data_loader : iterable
        Yields batches of input features as tensors (``.float()`` is applied).
    model_test : callable
        Model returning ``(reconstruction, latent)`` for a batch; only the
        latent is kept.
    device : str or torch.device
        Device each batch is moved to before encoding.

    Returns
    -------
    np.ndarray
        Latents of all batches stacked along axis 0, in iteration order.

    Raises
    ------
    ValueError
        If ``data_loader`` yields no batches.
    """
    latent_batches = []
    # Pure inference: disable autograd so no computation graph is kept alive.
    with torch.no_grad():
        for inputs in data_loader:
            inputs = inputs.float().to(device)
            _, latent = model_test(inputs)
            latent_batches.append(latent.detach().cpu().numpy())

    if not latent_batches:
        raise ValueError("data_loader yielded no batches; nothing to encode")
    return np.vstack(latent_batches)

# Encode every loaded feature vector into the VAE latent space.
latent_vectors = extractGlobalFeatForAllData(
    data_loader=data_loader,
    model_test=model_test,
    device=device,
)
print(latent_vectors.shape)

(1812, 128)


In [7]:
pointnetae = PointNet_AE(3, 2048)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
ae_checkpoint = torch.load(
    os.path.join(CHECKPOINT_FOLDER, 'PointNetAE.pth'),  # change the path to your own checkpoint file
    map_location=device,
)
# Use the device chosen when the VAE was loaded instead of assuming CUDA.
pointnetae.to(device)
pointnetae.load_state_dict(ae_checkpoint['state_dict'])
pointnetae.eval()

print(ae_checkpoint['epoch'])

192


In [15]:
def computeAllData(global_feat, pointnetae, model_test, n_obj=10, top_k=25, mod_range=10, n_step=20, device=None):
    """Decode shapes while sweeping the highest-variance latent dimensions.

    Only the first ``n_obj`` rows of ``global_feat`` are used. The ``top_k``
    dimensions with the largest variance across those rows are swept one at a
    time. Each is set to ``n_step`` evenly spaced values between its min and
    max, while all other dimensions keep their original values. Each modified
    batch is decoded by ``model_test.decoder`` followed by
    ``pointnetae.decoder``.

    Parameters
    ----------
    global_feat : np.ndarray, shape (n_samples, latent_dim)
        Latent vectors; not modified.
    pointnetae : object with a ``decoder`` callable
    model_test : object with a ``decoder`` callable
    n_obj : int
        Number of leading rows to use.
    top_k : int
        Number of highest-variance dimensions to sweep.
    mod_range : int
        Unused; kept for backward compatibility.
    n_step : int
        Number of values per sweep.
    device : str, torch.device or None
        Device for decoding. If None, the device of ``model_test``'s first
        parameter is used (CPU when it has none).

    Returns
    -------
    X_list : dict
        Maps each swept dimension index to an array whose first axis has
        length ``n_step`` (one decoded batch per sweep value).
    significant_feature_idx : np.ndarray
        All dimension indices sorted by decreasing variance.
    """
    feats = global_feat[:n_obj]
    min_, max_ = np.min(feats, axis=0), np.max(feats, axis=0)
    # Dimensions that vary most across the selected objects come first.
    significant_feature_idx = np.argsort(np.var(feats, axis=0))[::-1]

    if device is None:
        params = getattr(model_test, "parameters", None)
        first_param = next(params(), None) if callable(params) else None
        device = first_param.device if first_param is not None else "cpu"

    base = torch.from_numpy(feats).to(device)

    X_list = dict()
    with torch.no_grad():
        for idx in significant_feature_idx[:top_k]:
            X_feat_list = []
            for step in np.linspace(min_[idx], max_[idx], n_step):
                # Clone so neither the base tensor nor the caller's array is mutated.
                new_features = base.clone()
                new_features[:, idx] = float(step)
                X = pointnetae.decoder(model_test.decoder(new_features))
                X_feat_list.append(X.detach().cpu().numpy())
            X_list[idx] = np.asarray(X_feat_list)

    return X_list, significant_feature_idx

# Sweep the most varying latent dimensions and decode the resulting shapes.
X_list, sigificant_feature_idx = computeAllData(
    global_feat=latent_vectors,
    pointnetae=pointnetae,
    model_test=model_test,
)

In [9]:
# def examineObj(obj_idx):
#     for idx in X_list.keys():
#         draw3DpointsSlider(X_list[idx], idx, obj_idx)
#         end = int(input("Are you finished? 0: No, 1: Yes"))
#         if (end == 1):
#             break
        
#         clear_output(wait=True)

# examineObj(0)

In [17]:
# Rank of the latent dimension to inspect (0 = highest variance).
feat_idx = 1
selected_dim = sigificant_feature_idx[feat_idx]
draw3DpointsSlider(X_list[selected_dim], selected_dim, 0)