In [1]:
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, utils

import numpy as np
import easydict
import h5py
import argparse
import os
import open3d as o3d
import itertools
import math, random

from path import Path
import scipy.spatial.distance
import plotly.graph_objects as go
import plotly.express as px
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

from model.model import PointNetAE
from model.model1 import PointNetAE1
from dataset.dataset import Dataset
from dataset.dataset import create_datasets_and_dataloaders


In [2]:
# Runtime configuration for this evaluation notebook. EasyDict gives
# attribute access (args.batch_size) in the style of an argparse.Namespace.
args = easydict.EasyDict({
    'train': False,                                              # train/eval flag read by create_datasets_and_dataloaders — exact effect not visible here
    'in_data_file': 'data/ModelNet/modelnet_classification.h5',  # HDF5 dataset file
    'batch_size': 1,                                             # batch size (presumably used by the dataloaders; the visualisation cell prints a batch of 1)
    'model': 'saved_models/autoencoder_50.pth',                  # path to the trained autoencoder weights
    'out_norm_input': True,                                      # read by create_datasets_and_dataloaders — TODO document meaning
    'n_epochs': 50,                                              # number of epochs (not used by the code visible in this notebook)
    'n_workers': 0,                                              # presumably DataLoader worker processes — TODO confirm
})

In [3]:
# Prefer the first CUDA GPU when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

cuda:0


In [4]:
def visualize_rotate(data):
    """Wrap plotly traces in a figure with a camera-orbit animation.

    A 'Play' button is attached; pressing it animates the camera through
    a sequence of eye positions obtained by rotating a fixed starting eye
    point about the z-axis.

    Args:
        data: list of plotly traces (e.g. ``go.Scatter3d``) to display.

    Returns:
        A ``go.Figure`` holding the traces, the Play button and the
        animation frames.
    """
    start_x, start_y, start_z = 1.25, 1.25, 0.8

    def rotate_z(x, y, z, theta):
        # Rotate (x, y) about the z-axis by multiplying it, as a complex
        # number, by e^{i*theta}; z is unchanged.
        rotated = np.exp(1j*theta)*(x+1j*y)
        return np.real(rotated), np.imag(rotated), z

    def camera_frame(angle):
        # One animation frame: only the camera eye position changes.
        ex, ey, ez = rotate_z(start_x, start_y, start_z, -angle)
        eye = dict(x=ex, y=ey, z=ez)
        return dict(layout=dict(scene=dict(camera=dict(eye=eye))))

    # Angles 0, 0.1, 0.2, ... (radians), stopping before 10.26.
    frames = [camera_frame(angle) for angle in np.arange(0, 10.26, 0.1)]

    play_settings = dict(frame=dict(duration=50, redraw=True),
                         transition=dict(duration=0),
                         fromcurrent=True,
                         mode='immediate')
    play_button = dict(label='Play',
                       method='animate',
                       args=[None, play_settings])
    menu = dict(type='buttons',
                showactive=False,
                y=1,
                x=0.8,
                xanchor='left',
                yanchor='bottom',
                pad=dict(t=45, r=10),
                buttons=[play_button])

    return go.Figure(data=data,
                     layout=go.Layout(updatemenus=[menu]),
                     frames=frames)


def pcshow(xs, ys, zs):
    """Show a point cloud as an interactive, rotatable 3D scatter plot.

    Args:
        xs, ys, zs: per-point x, y and z coordinates (presumably equal-length
            1-D sequences, as produced by slicing an (N, 3) array).
    """
    scatter = go.Scatter3d(x=xs, y=ys, z=zs, mode='markers')
    figure = visualize_rotate([scatter])

    # Style every marker trace: size-2 points outlined in DarkSlateGrey.
    marker_style = dict(size=2, line=dict(width=2, color='DarkSlateGrey'))
    figure.update_traces(marker=marker_style, selector=dict(mode='markers'))

    figure.show()

In [5]:
class Normalize(object):
    """Center a point cloud at the origin and scale it into the unit sphere.

    The per-column mean (centroid) is subtracted, then every point is divided
    by the largest Euclidean distance of any point from the centroid, so the
    farthest point ends up at distance 1.
    """

    def __call__(self, pointcloud):
        """Return a normalized copy of ``pointcloud``.

        Args:
            pointcloud: 2-D array with one point per row (e.g. shape (N, 3)).

        Returns:
            A new array of the same shape; the input is not modified. If all
            points coincide (zero spread), the centered cloud (all zeros) is
            returned instead of dividing by zero and producing NaNs.

        Raises:
            AssertionError: if ``pointcloud`` is not 2-D.
            ValueError: if ``pointcloud`` contains no points.
        """
        assert len(pointcloud.shape) == 2

        if pointcloud.shape[0] == 0:
            raise ValueError("cannot normalize an empty point cloud")

        # Subtraction allocates a new array, so the in-place divide below
        # never touches the caller's data.
        norm_pointcloud = pointcloud - np.mean(pointcloud, axis=0)
        max_dist = np.max(np.linalg.norm(norm_pointcloud, axis=1))

        # A degenerate cloud (all points identical) has zero spread; dividing
        # by it would turn every coordinate into NaN.
        if max_dist > 0:
            norm_pointcloud /= max_dist

        return norm_pointcloud

In [22]:
in_dim = 3           # coordinates per point (x, y, z)
num_points = 2048    # points per cloud passed to PointNetAE1
sample_index = 133   # index of the train_dataloader batch to visualise


def _to_xyz(tensor):
    """Return ``tensor`` flattened to an (N, 3) NumPy array on the CPU.

    ``reshape`` is used instead of ``view`` because a transposed tensor is
    non-contiguous: ``view(-1, 3)`` on the transposed model output only works
    when the batch size is 1 and raises a RuntimeError for larger batches.
    """
    return tensor.detach().cpu().reshape(-1, 3).numpy()


def _show_xyz(xyz):
    """Plot an (N, 3) array of points with ``pcshow``."""
    pcshow(xyz[:, 0], xyz[:, 1], xyz[:, 2])


autoencoder = PointNetAE1(in_dim, num_points)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE: torch.load unpickles the file — only load trusted checkpoints.
autoencoder.load_state_dict(torch.load(args.model, map_location=device))
autoencoder = autoencoder.to(device).eval()

train_dataset, train_dataloader, test_dataset, test_dataloader, n_classes = create_datasets_and_dataloaders(args)

# Fetch only the requested batch instead of iterating (and loading) the rest
# of the dataset after it.
batch = next(itertools.islice(train_dataloader, sample_index, None), None)
if batch is None:
    raise IndexError(
        f"train_dataloader has fewer than {sample_index + 1} batches")

points, gt_classes = batch
print(*points.shape)

# Input cloud.
input_xyz = _to_xyz(points)
print(*input_xyz.shape)
_show_xyz(input_xyz)

# Reconstruction. The recorded output of this cell shows the model returning
# a channel-first (batch, 3, num_points) tensor, so the coordinate axis is
# moved last before flattening.
points = points.to(device)
with torch.no_grad():
    cloud = autoencoder(points)
print(*cloud.shape)
recon_xyz = _to_xyz(cloud.transpose(2, 1))
print(*recon_xyz.shape)
_show_xyz(recon_xyz)
        
    
        

        
        
        


1 2048 3
2048 3


1 3 2048
2048 3
2048 3
