In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
sys.path.insert(1, '/home/cem/Documents/imps/src')

import numpy as np
import open3d as o3d
import torch
from matplotlib import cm
import matplotlib.pyplot as plt
import torch
import torch.optim as optim

from imps.data import ScanNetScene, CLASS_NAMES
from imps.sqn.model import SQN
from imps.sqn.data_utils import prepare_input

# Path to a single ScanNet scene. Set the SCANNET_SCENE_DIR environment variable
# to point elsewhere instead of editing this absolute local path.
SCENE_DIR = os.environ.get(
    'SCANNET_SCENE_DIR', '/home/cem/Documents/datasets/ScanNet/scans/scene0000_00'
)
# Point-count argument passed to ScanNetScene.create_if_data
# (presumably the number of sampled surface points — confirm in imps.data).
N_POINTS = int(1.5e5)
# Not important yet: the output this controls is discarded below. Keep it small
# so data processing stays quick.
RESOLUTION = 25
# Passed to create_if_data together with the vicinity outputs.
# NOTE(review): presumably the noise scales used to sample vicinity points — confirm.
SIGMAS = np.array([0.5, 0.1, 0.01])
DEVICE = 'cpu'

In [3]:
# Load the scene and sample surface points plus vicinity query points with their
# target distances (the first and last outputs of create_if_data are unused here).
scene = ScanNetScene(SCENE_DIR)

_, surface_points, surface_colors, vicinity_points, vicinity_distances, _ = scene.create_if_data(
    RESOLUTION, N_POINTS, SIGMAS
)

# Input features are the xyz coordinates only (the model below is built with
# d_feature=3); surface_colors are loaded but not used.
# Every tensor gets a leading batch dimension of 1. torch.tensor always copies,
# so the tensors never alias the numpy arrays.
features = torch.tensor(surface_points, dtype=torch.float32).unsqueeze(0).to(DEVICE)
# Coordinates for neighbor/pooling precomputation. Left on the CPU because
# prepare_input receives its own `device` argument.
xyz = torch.tensor(surface_points, dtype=torch.float32).unsqueeze(0)
# Query points are fed to the model together with `features`, so they must live
# on the same device.
query = torch.tensor(vicinity_points, dtype=torch.float32).unsqueeze(0).to(DEVICE)

In [4]:
# Precompute the neighbor and pooling structures the encoder consumes.
# NOTE(review): k is presumably the neighbor count and sub_sampling_ratio the
# per-layer downsampling factor; num_layers=3 presumably has to match
# len(encoder_dims) below — confirm in imps.sqn.data_utils.
input_points, input_neighbors, input_pools = prepare_input(
    xyz, k=16, num_layers=3, sub_sampling_ratio=4, device=DEVICE
)

sqn = SQN(
    d_feature=3, d_in=8, encoder_dims=[8, 32, 64], decoder_dims=[64, 32, 1],
    device=DEVICE, activation=torch.nn.ReLU(inplace=True),
)
# map_location remaps stored tensors onto DEVICE, so a checkpoint saved on a GPU
# still loads when DEVICE is 'cpu'. torch.load unpickles the file: only load
# trusted checkpoints. The trailing ';' hides load_state_dict's key-match report.
sqn.load_state_dict(torch.load('../../data/sqn-dist-pred', map_location=DEVICE));

In [5]:
# Run inference: predict a distance for every query point.
sqn.eval()

with torch.no_grad():
    # Call the module itself rather than .forward() so nn.Module.__call__ runs
    # any registered forward hooks.
    query_pred = sqn(features, input_points, input_neighbors, input_pools, query)
    # Drop singleton dims and move the result to a numpy array on the host.
    pred_dists = query_pred.detach().squeeze().cpu().numpy()

In [6]:
# Mean absolute error of predicted vs. target distances.
# Flatten both sides first: an (M,) vs (M, 1) pair would otherwise silently
# broadcast into an (M, M) difference matrix and give a meaningless mean.
DIST_SCALE = 8  # NOTE(review): unexplained factor kept from the original; presumably undoes a scene normalization — confirm
pred_flat = np.asarray(pred_dists).reshape(-1)
true_flat = np.asarray(vicinity_distances).reshape(-1)
assert pred_flat.shape == true_flat.shape, (pred_flat.shape, true_flat.shape)
np.abs(pred_flat - true_flat).mean() * DIST_SCALE

0.18752353003259303

In [7]:
# Show the query points colored by predicted distance (darker red = larger).
pts = query.squeeze().detach().cpu().numpy()
dst = pred_dists
# Min-max normalize to [0, 1] for the colormap. Guard against a constant
# prediction, which would otherwise divide by zero and yield NaN colors.
dst_range = dst.max() - dst.min()
dst_norm = (dst - dst.min()) / dst_range if dst_range > 0 else np.zeros_like(dst)

# matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9;
# pyplot.get_cmap is still supported.
reds = plt.get_cmap('Reds')
vic_pcd = o3d.geometry.PointCloud()
vic_pcd.points = o3d.utility.Vector3dVector(pts)
# The colormap returns RGBA; drop alpha so each color is an RGB triple.
vic_pcd.colors = o3d.utility.Vector3dVector(reds(dst_norm)[:, :-1])

o3d.visualization.draw_geometries([vic_pcd])

In [None]:
# Persist the model weights.
# NOTE(review): this is the same path the weights are loaded from earlier, and
# nothing in this notebook trains the model, so running this cell only rewrites
# the existing checkpoint. Consider removing it or saving to a new file.
checkpoint_path = '../../data/sqn-dist-pred'
torch.save(sqn.state_dict(), checkpoint_path)