In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
sys.path.insert(1, '/home/cem/Documents/imps/src')

import numpy as np
import matplotlib.pyplot as plt
import open3d as o3d
import torch
import torch.optim as optim
from sklearn.metrics import jaccard_score
import matplotlib.pyplot as plt

from sklearn.decomposition import PCA
from matplotlib import cm

from imps.sqn.model import SQN
from imps.sqn.data_utils import prepare_input
from imps.metrics import compute_iou

# --- Run configuration -------------------------------------------------------
DEVICE = 'cpu'
POS_EMBEDDING = True
MODEL_DIR = './logs/occ-semantic-emb-skip-seed=31-lr=sch-1638545527/model'

DATA_ROOT = '/mnt/data.nas/shareddata/6G-futurelab/synthetic_room_dataset/rooms_04'
scene = '00000004'
scene_dir = os.path.join(DATA_ROOT, scene)

# --- Semantic classes --------------------------------------------------------
# Class ids taken from:
# https://github.com/autonomousvision/convolutional_occupancy_networks/blob/master/scripts/dataset_synthetic_room/build_dataset.py#L29
# Human-readable names taken from:
# https://gist.github.com/tejaskhot/15ae62827d6e43b91a4b0c5c850c168e
classes = ['04256520', '03636649', '03001627', '04379243', '02933112']
class_names = ['sofa', 'lamp', 'chair', 'table', 'cabinet']
# Sort both sequences together by class id so that index i refers to the same
# class in `classes` and `class_names`.
classes, class_names = zip(*sorted(zip(classes, class_names)))

# Map semantic label -> display name: the two special labels first, then the
# sorted object classes under their positional index.
class2name = {-1: "ground-plane", 6: "free-space"}
class2name.update(enumerate(class_names))

def sigmoid(x):
    """Element-wise logistic function 1 / (1 + exp(-x)).

    Evaluated as exp(-log(1 + exp(-x))) via ``np.logaddexp`` so that large
    negative inputs do not overflow ``exp`` (the naive form emits a
    RuntimeWarning there); the result saturates cleanly at 0 and 1.

    Args:
        x: scalar or array-like of logits.

    Returns:
        NumPy scalar/array with the same shape as ``x``, values in [0, 1].
    """
    return np.exp(-np.logaddexp(0, -np.asarray(x)))
    
def get_loss(logits, labels, pos_weight):
    """Mean binary cross-entropy (with logits) over every element of the batch.

    Args:
        logits: raw (pre-sigmoid) predictions; first dimension is the batch.
        labels: targets with the same number of elements per batch entry.
        pos_weight: weight applied to the positive class, forwarded to the
            BCE-with-logits criterion.

    Returns:
        Scalar tensor: the per-element losses averaged over all elements.
    """
    batch = logits.shape[0]
    flat_logits = logits.reshape(batch, -1)
    flat_labels = labels.reshape(batch, -1)

    # Compute the un-reduced loss first, then average explicitly.
    per_element = torch.nn.functional.binary_cross_entropy_with_logits(
        flat_logits, flat_labels, pos_weight=pos_weight, reduction='none'
    )
    return per_element.mean()

def get_semantic_iou(logits, labels, n_class):
    """Per-class IoU (Jaccard score) of arg-max predictions against labels.

    Args:
        logits: array of class scores; the class axis is the last one.
        labels: integer class label per sample.
        n_class: number of classes; ids 0 .. n_class-1 are evaluated.

    Returns:
        np.ndarray of length ``n_class`` holding one binary (one-vs-rest)
        Jaccard score per class.
    """
    predicted = np.argmax(logits, axis=-1)

    # One-vs-rest: binarise both labels and predictions for each class id.
    per_class = [
        jaccard_score(
            (labels == cls).astype(int),
            (predicted == cls).astype(int),
            pos_label=1,
        )
        for cls in range(n_class)
    ]

    return np.array(per_class)

def pos_embed(pos, L=10):
    """Sinusoidal positional embedding of ``pos``.

    For every level ``l`` in ``range(L)`` a sine and a cosine of the scaled
    input are computed and all of them are concatenated along the last axis,
    so the last dimension grows by a factor of ``2 * L``.

    Args:
        pos: NumPy array of positions.
        L: number of frequency levels.

    Returns:
        float64 array whose last dimension is ``2 * L`` times that of ``pos``.
    """
    embs = []

    for l in range(L):
        # NOTE(review): `2 ^ l` is a bitwise XOR, not the power 2**l, so the
        # multipliers for L=10 are 2, 3, 0, 1, 6, 7, 4, 5, 10, 11 rather than
        # 1, 2, 4, ..., 512. It is deliberately left untouched here because the
        # checkpoint loaded by this notebook was presumably trained with these
        # exact features -- confirm against the training code before changing.
        freq = (2 ^ l) * np.pi
        embs += [np.sin(freq * pos), np.cos(freq * pos)]

    # `np.float` was only an alias for the builtin float (float64) and has been
    # removed from NumPy; name the dtype explicitly.
    return np.concatenate(embs, axis=-1).astype(np.float64)

def get_data(scene_points, query_points, query_occ, embedding=False, seed=None):
    """Turn one scene plus its query set into batched (batch size 1) model inputs.

    Args:
        scene_points: scene point coordinates (points along the first axis).
        query_points: query point coordinates.
        query_occ: occupancy target per query point.
        embedding: if True the input features are ``pos_embed(scene_points)``,
            otherwise the raw coordinates are used as features.
        seed: seed for the point shuffle. ``None`` leaves the global torch RNG
            untouched, ``-1`` seeds it with 31, any other value is used as is.

    Returns:
        Tuple ``(features, input_points, input_neighbors, input_pools, query,
        query_labels)``; the three ``input_*`` entries are whatever
        ``prepare_input`` returns for the shuffled coordinates.
    """
    def _as_batch(array):
        # Float tensor with a leading batch dimension of size 1.
        return torch.FloatTensor(array).unsqueeze(0)

    feature_source = pos_embed(scene_points) if embedding else scene_points
    features = _as_batch(feature_source).to(DEVICE)
    xyz = _as_batch(scene_points)

    # NOTE: the pipeline should be permutation invariant, but in practice it is
    # not. Hypothesis: after permuting and sub-sampling, the kNN up-sampling
    # step picks different corresponding features. The shuffle itself is
    # required because the input points are ordered by object.
    if seed is not None:
        torch.manual_seed(31 if seed == -1 else seed)

    # Apply the same shuffle to coordinates and features.
    shuffle = torch.randperm(xyz.size()[1])
    xyz, features = xyz[:, shuffle], features[:, shuffle]

    query = _as_batch(query_points).to(DEVICE)
    query_labels = _as_batch(query_occ).to(DEVICE)

    input_points, input_neighbors, input_pools = prepare_input(
        xyz, k=8, num_layers=3, sub_sampling_ratio=4, device=DEVICE
    )

    return features, input_points, input_neighbors, input_pools, query, query_labels

In [None]:
scene_pcd = o3d.io.read_point_cloud(os.path.join(scene_dir, 'pointcloud0.ply'));

# NOTE: this rebinds `scene` from the scene-id string to the loaded archive.
scene = np.load(os.path.join(scene_dir, 'pointcloud', 'pointcloud_00.npz'))
query_iou = np.load(os.path.join(scene_dir, 'points_iou', 'points_iou_00.npz'))

scene_points = scene['points']
query_points = query_iou['points']
# Occupancies are stored bit-packed. np.unpackbits always returns a multiple of
# 8 entries, so drop any padding bits to keep the mask aligned with the points.
query_occ = np.unpackbits(query_iou['occupancies'])[:query_points.shape[0]]
query_semantics = query_iou['semantics']

# Positive-class weight = (#free query points) / (#occupied query points).
pos_w = np.sum(query_occ==0) / np.sum(query_occ==1)
pos_w = torch.FloatTensor([pos_w]).to(DEVICE)

# Point cloud of the ground-truth occupied query points.
query_pcd = o3d.geometry.PointCloud()
query_pcd.points = o3d.utility.Vector3dVector(query_points[query_occ==1])

# d_feature=60 equals the pos_embed output width for 3-D points with the
# default L=10 (3 * 2 * 10) -- assumes scene_points are 3-D and
# POS_EMBEDDING is True; TODO confirm when toggling the embedding.
sqn = SQN(d_feature=60, d_in=64, encoder_dims=[32, 64, 128], decoder_dims=[128, 32, 1], device=DEVICE, 
          skip_connections=True, second_head=5)

features, input_points, input_neighbors, input_pools, query, query_labels = get_data(scene_points, 
                                                                                     query_points,
                                                                                     query_occ,
                                                                                     embedding=POS_EMBEDDING,
                                                                                     seed=31)

# map_location lets a checkpoint that was saved from a GPU run be restored on
# the device this notebook actually uses.
sqn.load_state_dict(torch.load(MODEL_DIR, map_location=DEVICE));

### Evaluate Accuracy

In [None]:
sqn.eval()
with torch.no_grad():
    logits, _ = sqn.forward(features, input_points, input_neighbors, input_pools, query)
    # Convert logits to occupancy probabilities before thresholding. The
    # previous version computed the sigmoid but then thresholded the raw
    # logits at 0.5 instead.
    probs = torch.sigmoid(logits)

# Binary occupancy prediction per query point (probability > 0.5).
# `np.int` has been removed from NumPy; the builtin int is the same dtype.
pred = (probs.detach().cpu().numpy().squeeze() > 0.5).astype(int)
gold = query_labels.detach().cpu().numpy().squeeze()

# IoU between predicted and ground-truth occupancy.
print(jaccard_score(gold, pred))

In [None]:
# Keep only the query points that were predicted as occupied and show them.
predicted_occupied = pred == 1
occ_pts = query_points[predicted_occupied]

occ_pcd = o3d.geometry.PointCloud()
occ_pcd.points = o3d.utility.Vector3dVector(occ_pts)

o3d.visualization.draw_geometries([occ_pcd])

### Extract features

In [None]:
sqn.eval()
with torch.no_grad():
    # Encode the scene once, then gather the per-layer features at the queries.
    encoder_list = sqn.encoder(features, input_points, input_neighbors, input_pools)
    layer_latents = sqn.get_features(query, encoder_list, input_points)

# Single descriptor per query: all layer features joined along the last axis.
latent = torch.cat(layer_latents, dim=-1).squeeze().detach().cpu().numpy()
# The individual layer features, converted to NumPy as well.
layer_latents = [layer.squeeze().detach().cpu().numpy() for layer in layer_latents]

### Extract scene-specific latents

In [None]:
# Project the query latents onto their first two principal components.
f = latent

pca = PCA(n_components=2)
pca.fit(f)
latent_pc = pca.transform(f)

### Plot semantics (one-shot)

In [None]:
fig, ax = plt.subplots(figsize=(11, 11))
# Skip the ground-plane (-1) and free-space (6) labels.
# NOTE(review): the original comment also mentioned walls, but no wall id is
# listed here or in class2name -- confirm which label walls use.
ignore = (-1, 6)

for i in np.unique(query_semantics):
    if i in ignore:
        continue
    mask = query_semantics == i
    # Fall back to the raw id so that a label without an entry in class2name
    # is still plotted instead of raising a KeyError.
    label = class2name.get(int(i), f"class {int(i)}")
    ax.scatter(latent_pc[mask, 0], latent_pc[mask, 1], label=label)

ax.set_xlabel("PC 1")
ax.set_ylabel("PC 2")
ax.set_title("Query latents by semantic class (PCA)")
ax.legend()
ax.grid()

In [None]:
from imps.utils.libmise import MISE

resolution0 = 16
upsampling_steps = 3
padding = 0.1
box_size = 1 + padding
# Occupancy probability at which the surface is extracted.
threshold = 0.99
# The decoder outputs raw logits and those are what MISE receives below, so the
# probability threshold has to be expressed in logit space. Passing 0.99
# directly would make MISE refine around the logit=0.99 level set (p ~ 0.73)
# instead of the p=0.99 surface that is extracted from sigmoid(value_grid).
logit_threshold = np.log(threshold) - np.log(1. - threshold)

mesh_extractor = MISE(resolution0, upsampling_steps, logit_threshold)

with torch.no_grad():
    sqn.eval()
    encoder_list = sqn.encoder(features, input_points, input_neighbors, input_pools)

# Per refinement round: the evaluated coordinates and their probabilities
# (kept for visualisation only).
eval_points = []
eval_values = []

points = mesh_extractor.query()
while points.shape[0] != 0:
    # Grid indices -> unit cube
    pointsf = points / mesh_extractor.resolution
    # Unit cube -> padded bounding box centred at the origin
    pointsf = box_size * (pointsf - 0.5)
    eval_points.append(pointsf)
    pointsf = torch.FloatTensor(pointsf).to(DEVICE).unsqueeze(0)
    # Evaluate the decoder at the requested points (raw logits)
    with torch.no_grad():
        values, _ = sqn.decoder(encoder_list, input_points, pointsf)
        values = values.squeeze().cpu().numpy()
        eval_values.append(sigmoid(values))

    # Feed the logits back to MISE as float64 and fetch the next batch
    values = values.astype(np.float64)
    mesh_extractor.update(points, values)
    points = mesh_extractor.query()

# Dense grid of logits
value_grid = mesh_extractor.to_dense()

In [None]:
# This does not work yet: the model is so overfitted to the training query
# points that it fails to predict occupancy for points queried around the scene.
# plt.get_cmap replaces matplotlib.cm.get_cmap, which is deprecated and has
# been removed from recent Matplotlib releases.
cmap = plt.get_cmap('Reds')
# Refinement round to display; clamped so that a run with fewer rounds than
# expected shows the last one instead of raising an IndexError.
idx = min(5, len(eval_points) - 1)

eval_pcd = o3d.geometry.PointCloud()
eval_pcd.points = o3d.utility.Vector3dVector(eval_points[idx])
# Colour each point by its predicted probability (RGBA -> RGB).
eval_pcd.colors = o3d.utility.Vector3dVector(cmap(eval_values[idx])[:, :-1])

o3d.visualization.draw_geometries([occ_pcd, eval_pcd])

In [None]:
import mcubes
import trimesh

# Extract the iso-surface of the occupancy probabilities at `threshold`.
occupancy_prob = sigmoid(value_grid)
vertices, triangles = mcubes.marching_cubes(occupancy_prob, threshold)

mesh = trimesh.Trimesh(vertices=vertices, faces=triangles)
mesh.show()