In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
sys.path.insert(1, '/home/cem/Documents/imps/src')

import numpy as np
import open3d as o3d
import torch
from matplotlib import cm
import matplotlib.pyplot as plt
import torch
import torch.optim as optim

from imps.data import ScanNetScene, CLASS_NAMES
from imps.sqn.model import SQN
from imps.sqn.data_utils import prepare_input

SCENE_DIR = '/home/cem/Documents/datasets/ScanNet/scans/scene0000_00'
N_POINTS = int(1.5e5)
# Not important we are not using this here yet. Keep this small for quick data processing
RESOLUTION = 25
SIGMAS = np.array([0.5, 0.1, 0.01])
LABEL_RATIO = 0.5
N_LABEL = int(N_POINTS*LABEL_RATIO)
DEVICE = 'cuda'

def get_loss(pred_dist, distances):
    dist_criterion = torch.nn.L1Loss(reduction='none')    
    dist_loss = dist_criterion(pred_dist.squeeze(), distances.squeeze()).sum(dim=-1).mean()
    return dist_loss

In [2]:
scene = ScanNetScene(SCENE_DIR)

_, surface_points, surface_colors, vicinity_points, vicinity_distances, _ = scene.create_if_data(
    RESOLUTION, N_POINTS, SIGMAS
)

query_idxs = np.random.choice(N_POINTS, N_LABEL, replace=False)
query_points = vicinity_points[query_idxs]
query_labels = vicinity_distances[query_idxs]

In [3]:
# features = np.concatenate([surface_points, surface_colors], axis=-1)
features = torch.FloatTensor(surface_points).unsqueeze(0).to(DEVICE)
xyz = torch.FloatTensor(surface_points).unsqueeze(0)
query = torch.FloatTensor(query_points).unsqueeze(0).to(DEVICE)
query_labels = torch.FloatTensor(query_labels).unsqueeze(0).to(DEVICE)

input_points, input_neighbors, input_pools = prepare_input(xyz, k=16, num_layers=3, sub_sampling_ratio=4, 
                                                           device=DEVICE)

sqn = SQN(d_feature=3, d_in=8, encoder_dims=[8, 32, 64], decoder_dims=[64, 32, 1], device=DEVICE,
          activation=torch.nn.ReLU(inplace=True))
optimizer = optim.Adam(sqn.parameters(), lr=1e-4)

In [4]:
sqn.train()

for e in range(1000):
    optimizer.zero_grad()
    
    query_pred = sqn.forward(features, input_points, input_neighbors, input_pools, query)
    loss = get_loss(query_pred, query_labels)
    loss.backward()
    
    optimizer.step()
    
    print(f"Epoch {e+1}: {round(loss.item(), 5)}")

Epoch 1: 17993.69531
Epoch 2: 16727.82812
Epoch 3: 15528.44727
Epoch 4: 14398.04883
Epoch 5: 13347.93262
Epoch 6: 12376.58789
Epoch 7: 11474.95215
Epoch 8: 10637.53711
Epoch 9: 9861.80273
Epoch 10: 9148.3418
Epoch 11: 8493.79102
Epoch 12: 7884.93555
Epoch 13: 7314.35449
Epoch 14: 6787.20215
Epoch 15: 6310.84668
Epoch 16: 5877.09619
Epoch 17: 5483.12695
Epoch 18: 5124.20312
Epoch 19: 4803.87354
Epoch 20: 4519.8501
Epoch 21: 4262.81201
Epoch 22: 4027.73413
Epoch 23: 3811.07007
Epoch 24: 3610.65918
Epoch 25: 3428.05322
Epoch 26: 3262.20044
Epoch 27: 3109.23999
Epoch 28: 2967.7666
Epoch 29: 2838.74219
Epoch 30: 2720.2085
Epoch 31: 2611.76294
Epoch 32: 2512.98389
Epoch 33: 2422.74487
Epoch 34: 2341.90332
Epoch 35: 2269.99463
Epoch 36: 2207.35864
Epoch 37: 2152.17676
Epoch 38: 2103.46582
Epoch 39: 2059.15283
Epoch 40: 2019.19727
Epoch 41: 1982.54993
Epoch 42: 1949.20276
Epoch 43: 1918.60791
Epoch 44: 1890.47803
Epoch 45: 1865.13159
Epoch 46: 1842.70813
Epoch 47: 1822.75171
Epoch 48: 1805.253

KeyboardInterrupt: 

In [5]:
sqn.eval()

with torch.no_grad():
    query_pred = sqn.forward(features, input_points, input_neighbors, input_pools, query)
    pred_dists = query_pred.detach().squeeze().cpu().numpy()

In [None]:
pts = query.squeeze().detach().cpu().numpy()
dst = pred_dists
dst_norm = (dst - dst.min()) / (dst.max() - dst.min())

viridis = cm.get_cmap('Reds')
vic_pcd = o3d.geometry.PointCloud()
vic_pcd.points = o3d.utility.Vector3dVector(pts)
vic_pcd.colors = o3d.utility.Vector3dVector(viridis(dst_norm)[:, :-1])
    
o3d.visualization.draw_geometries([vic_pcd])

In [None]:
torch.save(sqn.state_dict(), '../../data/sqn-dist-pred')