In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
sys.path.insert(1, '/home/cem/Desktop/ov-workspace/src')

import numpy as np
import open3d as o3d
import torch
from matplotlib import cm
import matplotlib.pyplot as plt

from imps.data import ScanNetScene, CLASS_NAMES
from imps.ndf.local_model import MiniNDF

# Directory of the ScanNet scan to load. The SCANNET_SCENE_DIR environment
# variable overrides the default, so the notebook can run on a machine with a
# different dataset location without editing this cell.
SCENE_DIR = os.environ.get(
    'SCANNET_SCENE_DIR',
    '/home/cem/Documents/datasets/ScanNet/scans/scene0000_00',
)

# The three values below are forwarded positionally to
# ScanNetScene.create_if_data.
# NOTE(review): presumably the number of sampled points, the voxel-grid
# resolution and the sampling noise scales - confirm against create_if_data.
N_POINTS = 200_000  # same value as int(2e5)
RESOLUTION = 125
SIGMAS = np.array([0.5, 0.1, 0.01])

In [None]:
# Open the scan and request the Open3D-format outputs, which are only used for
# the visual sanity check.
# NOTE(review): the meaning of the positional `True` flag is not visible here -
# confirm against ScanNetScene.create_if_data.
scene = ScanNetScene(SCENE_DIR)
vis = scene.create_if_data(RESOLUTION, N_POINTS, SIGMAS, True, o3d_format=True)

In [None]:
# Visualize to make sure the data is correct.
# NOTE(review): entries 0 and 2 of `vis` presumably correspond to the voxel
# grid and the vicinity points (the order of the numeric outputs) - confirm.
o3d.visualization.draw_geometries([vis[idx] for idx in (0, 2)])

In [None]:
# Same call as for the visual check, but without `o3d_format`, so the outputs
# are numeric data instead of visual objects.
(
    voxel_grid,
    surface_points,
    vicinity_points,
    vicinity_distances,
    point_labels,
) = scene.create_if_data(RESOLUTION, N_POINTS, SIGMAS, True)

In [None]:
# Concatenation order is [surface, vicinity]: every array built below keeps the
# surface samples first and the vicinity samples second so rows stay aligned.
query_points = np.concatenate([surface_points, vicinity_points], axis=0)

# Surface points get a target distance of exactly zero.
distances = np.concatenate(
    [np.zeros(surface_points.shape[0]), vicinity_distances], axis=0
)

# Vicinity points carry no annotation, so they are given label 0.
semantic_labels = np.concatenate(
    [point_labels, np.zeros(vicinity_points.shape[0], dtype=int)]
)

# To train on vicinity points only, replace the three arrays above with
# vicinity_points, vicinity_distances and an all-zero integer label array.

# Further down-sample the number of points.
ratio = 0.5
N_points = query_points.shape[0]
N_sub = int(ratio * N_points)

# Seeded so that re-running the notebook selects the same subset.
# choice(..., replace=False) already returns unique indices in random order,
# so no additional permutation is needed.
SUBSAMPLE_SEED = 0
rng = np.random.default_rng(SUBSAMPLE_SEED)
shuffle_idxs = rng.choice(N_points, N_sub, replace=False)

# 0 is both vicinity points and "unannotated" points
ignored_label_inds = [0]
n_classes = len(CLASS_NAMES)

# Counts are taken over the full (pre-subsampling) label array.
class_counts = np.array(
    [np.sum(semantic_labels == c) for c in range(n_classes)]
)
class_counts[ignored_label_inds] = 0

# Guard the normalisation: with no annotated point at all (e.g. the
# vicinity-only setup) the division would produce NaN weights.
# NOTE(review): weights are proportional to class frequency, which emphasises
# frequent classes; inverse-frequency weighting is the usual choice for
# balancing - confirm this is intended.
total_count = class_counts.sum()
if total_count > 0:
    class_weights = class_counts / total_count
else:
    class_weights = np.zeros_like(class_counts, dtype=float)

# Convert to tensors with explicit dtypes. torch.as_tensor also accepts
# tensors, so re-running this cell does not fail on `voxel_grid`.
query_points = torch.as_tensor(query_points[shuffle_idxs], dtype=torch.float32)
distances = torch.as_tensor(distances[shuffle_idxs], dtype=torch.float32)
semantic_labels = torch.as_tensor(semantic_labels[shuffle_idxs], dtype=torch.int64)
voxel_grid = torch.as_tensor(voxel_grid, dtype=torch.float32)
class_weights = torch.as_tensor(class_weights, dtype=torch.float32)

In [None]:
# Sanity check of the training set: draw the sub-sampled query points coloured
# by their semantic label.

# Hand Open3D and colorize_labels plain NumPy arrays rather than torch tensors;
# Vector3dVector is documented as taking float64 arrays of shape (n, 3).
points_np = query_points.detach().cpu().numpy().astype(np.float64)
labels_np = semantic_labels.detach().cpu().numpy()

# Distances rescaled to [0, 1] for the alternative colouring below. A constant
# distance array maps to all zeros instead of dividing by zero.
dist_np = distances.squeeze().cpu().numpy()
dist_range = dist_np.max() - dist_np.min()
if dist_range > 0:
    dist_norm = (dist_np - dist_np.min()) / dist_range
else:
    dist_norm = np.zeros_like(dist_np)
# plt.get_cmap replaces matplotlib.cm.get_cmap, which was deprecated and later
# removed from Matplotlib.
reds_cmap = plt.get_cmap('Reds')

seg_pcd = o3d.geometry.PointCloud()
seg_pcd.points = o3d.utility.Vector3dVector(points_np)
# Alternative colouring by distance (drop the alpha channel of the RGBA output):
# seg_pcd.colors = o3d.utility.Vector3dVector(reds_cmap(dist_norm)[:, :3])
seg_pcd.colors = o3d.utility.Vector3dVector(scene.colorize_labels(labels_np))

o3d.visualization.draw_geometries([seg_pcd])

In [None]:
# Prepend a batch dimension of size one to every training tensor and move it
# to the GPU.
model_input = voxel_grid.unsqueeze(0).cuda()
model_query = query_points.unsqueeze(0).cuda()
distance_gold = distances.unsqueeze(0).cuda()
semantic_gold = semantic_labels.unsqueeze(0).cuda()
class_weights = class_weights.cuda()

# The model is built with one class per entry of the weight vector.
n_classes = class_weights.shape[0]
model = MiniNDF(n_classes, is_cuda=True)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
# Number of optimisation steps; each step feeds the whole sub-sampled point
# set through the model as a single batch.
N_EPOCHS = 1000

# The loss modules do not change between steps, so build them once instead of
# on every iteration.
dist_criterion = torch.nn.L1Loss(reduction='none')
seg_criterion = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='mean')

model.train()

for epoch in range(N_EPOCHS):
    optimizer.zero_grad()

    encoding = model.encoder(model_input)
    pred_dist, pred_logits = model.decoder(model_query, *encoding)

    # Element-wise L1 error, summed over the last dimension, then averaged.
    dist_loss = dist_criterion(pred_dist, distance_gold).sum(dim=-1).mean()
    # Cross-entropy weighted per class by `class_weights`.
    semantic_loss = seg_criterion(pred_logits, semantic_gold)
    loss = dist_loss + semantic_loss

    loss.backward()
    optimizer.step()

    # .item() logs plain floats instead of full tensor reprs
    # (device / grad_fn noise).
    print(f"Epoch {epoch}\n\tDistance Loss:{dist_loss.item()}\t Semantic Loss:{semantic_loss.item()}")

In [None]:
# Switch the model to evaluation mode and re-run the training query points
# through it without tracking gradients.
model.eval()

with torch.no_grad():
    encoding = model.encoder(model_input)
    decoded = model.decoder(model_query, *encoding)

# The decoder output unpacks into predicted distances and class logits.
pred_dist, pred_logits = decoded

In [None]:
# Draw the training query points coloured by the distance the model predicts
# for them.

# Index the batch dimension explicitly instead of using a bare squeeze(), which
# would also collapse the point axis when there is a single query point.
pts = model_query[0].detach().cpu().numpy().astype(np.float64)
dst = pred_dist.detach().cpu().numpy().reshape(-1)

# Rescale to [0, 1] for the colormap. Constant predictions map to all zeros
# instead of dividing by zero.
dst_range = dst.max() - dst.min()
if dst_range > 0:
    dst_norm = (dst - dst.min()) / dst_range
else:
    dst_norm = np.zeros_like(dst)

# plt.get_cmap replaces matplotlib.cm.get_cmap, which was deprecated and later
# removed from Matplotlib.
reds_cmap = plt.get_cmap('Reds')
vic_pcd = o3d.geometry.PointCloud()
vic_pcd.points = o3d.utility.Vector3dVector(pts)
# The colormap returns RGBA; Open3D colours are RGB, so drop the alpha channel.
vic_pcd.colors = o3d.utility.Vector3dVector(reds_cmap(dst_norm)[:, :3])

o3d.visualization.draw_geometries([vic_pcd])

In [None]:
# Predict a semantic class for every surface point (the full set, not the
# sub-sampled training subset) and draw the points coloured by that class.
model.eval()

with torch.no_grad():
    encoding = model.encoder(model_input)
    # Prepend the batch dimension before moving the points to the GPU.
    surface_query = torch.FloatTensor(np.expand_dims(surface_points, axis=0)).cuda()
    _, pred_logits = model.decoder(surface_query, *encoding)

# After squeezing, the first axis is treated as the class axis.
logits = pred_logits.squeeze().detach().cpu().numpy()
pred_labels = np.argmax(logits, axis=0)

seg_pcd = o3d.geometry.PointCloud()
seg_pcd.points = o3d.utility.Vector3dVector(surface_points)
label_colors = scene.colorize_labels(pred_labels)
seg_pcd.colors = o3d.utility.Vector3dVector(label_colors)

o3d.visualization.draw_geometries([seg_pcd])