In [1]:
import numpy as np
import open3d as o3d
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import sys
sys.path.append('../')
from dataset.PointCloudDataset import PointCloudDataset
from dataset.voxelDataset import VoxelDataset

import os, glob

import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
dataset_train  = PointCloudDataset('../dataset/modelnet40_normal_resampled', 
                                                train=True, 
                                                ndata=4000,
                                                file_extension='.txt', 
                                                npoints=4000
                                            )

Loaded 4000 point clouds from 40 classes


In [3]:
data = dataset_train[100]

In [4]:

DEVICE = 'cpu'
from FoldingNet import FoldNet, Encoder, Decoder
checkpoints_path = '../checkpoints/foldingnet'
model = FoldNet(num_points=4000).to(DEVICE)

# load model if exists
models_saved = glob.glob(os.path.join(checkpoints_path, 'model_*.pth'))
if len(models_saved) > 0:
    # get most recent model
    epoches_done = max([int(m.split('_')[-1].split('.')[0]) for m in models_saved])
    model_path = os.path.join(checkpoints_path, f'model_{epoches_done}.pth')
    print(f"Loading model from {model_path}")
    model.load_state_dict(torch.load(model_path))

Loading model from ../checkpoints/foldingnet/model_128.pth


In [5]:
point_cloud, label = data

point_cloud = point_cloud.unsqueeze(0)
point_cloud = point_cloud.to(DEVICE)

decoded, encoded = model(point_cloud)
decoded = decoded.squeeze(0).detach().cpu().numpy()
point_cloud = point_cloud.squeeze(0).detach().cpu().numpy()

In [6]:
#open3d visualization

pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(point_cloud)
o3d.visualization.draw_geometries([pcd])

pcd_dec = o3d.geometry.PointCloud()
pcd_dec.points = o3d.utility.Vector3dVector(decoded)
o3d.visualization.draw_geometries([pcd_dec])


In [7]:
encoded

tensor([[[-0.9603, -0.8074, -0.8323, -0.7691, -0.9854, -1.5732, -4.5183,
          -1.9784, -1.0632, -1.4052, -1.2694, -1.3722, -0.9359,  0.6353,
          -1.1027, -0.8134, -1.4012, -0.8539, -0.7609, -0.6760, -1.0719,
          -0.9875, -1.7166, -1.0708, -1.2411, -5.1059, -1.2581,  0.1480,
          -1.0721, -0.8017, -4.3857, -0.6380, -2.7391, -0.6922, -1.0916,
           2.8229, -1.0732, -0.3348, -1.0832, -0.8765, -1.4754, -1.1050,
          -3.8197, -0.9713, -0.6346, -0.9617, -0.7422, -3.5060, -1.5161,
          -0.6795, -1.1383,  1.0674, -0.7972, -1.0096, -0.7001, -0.7931,
          -3.1221,  0.1971, -1.0155, -0.9073, -1.4636, -0.6644, -1.0021,
          -0.8622, -1.0256, -1.4149,  1.5893, -1.9305, -0.7153, -0.6993,
          -1.1638,  0.7146, -0.8357, -1.0906, -1.5302, -1.1276, -1.6135,
          -0.7302, -0.9347, -0.7920, -1.2726, -1.1220, -0.9697, -4.1041,
          -1.9554, -0.5351, -1.1530, -0.7373, -1.0007, -1.1521, -0.8152,
          -1.1131,  2.3102, -0.8226, -0.7165, -0.94

: 