In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
sys.path.insert(1, '/home/cem/Documents/imps/src')

import numpy as np
import open3d as o3d
import torch
from matplotlib import cm
import matplotlib.pyplot as plt
from sklearn.metrics import jaccard_score
import pandas as pd

from imps.data import ScanNetScene, CLASS_NAMES
from imps.sqn.model import SQN
from imps.sqn.data_utils import prepare_input

# --- Scene / sampling configuration (all passed to the scene loader below) ---
# Directory of the ScanNet scene evaluated by this notebook.
SCENE_DIR = '/home/cem/Documents/datasets/ScanNet/scans/scene0000_00'
# Point budget handed to `scene.create_if_data`.
N_POINTS = 150_000
# RESOLUTION is not used by the evaluation in this notebook yet;
# keep it small so that data processing stays quick.
RESOLUTION = 25
SIGMAS = None

# --- Evaluation configuration ---
# Device used for the feature tensor, the input preparation and the SQN model.
DEVICE = 'cpu'
# Label ids whose point counts are zeroed out when the class weights are computed.
IGNORED_LABELS = [0]

def get_iou(logits, labels):
    """Compute one binary IoU (Jaccard score) per class.

    The predicted class of every sample is the argmax of ``logits`` over its
    last axis. For each class id ``0 .. len(CLASS_NAMES) - 1`` the ground
    truth and the prediction are turned into 0/1 masks ("is this class" vs.
    "is not") and scored with ``sklearn.metrics.jaccard_score``.

    Args:
        logits: array of class scores; the last axis indexes the classes.
        labels: array of ground-truth class ids, one per sample
            (assumed to line up element-wise with the argmax of ``logits``).

    Returns:
        ``np.ndarray`` of length ``len(CLASS_NAMES)`` with the IoU of each class.
    """
    predicted = np.argmax(logits, axis=-1)

    per_class_iou = [
        jaccard_score(
            (labels == class_id).astype(int),
            (predicted == class_id).astype(int),
            pos_label=1,
        )
        for class_id in range(len(CLASS_NAMES))
    ]

    return np.array(per_class_iou)

In [None]:
# Load the scene and sample the data used for evaluation.
scene = ScanNetScene(SCENE_DIR)

(
    voxel_grid,
    surface_points,
    surface_colors,
    vicinity_points,
    vicinity_distances,
    point_labels,
) = scene.create_if_data(RESOLUTION, N_POINTS, SIGMAS)

# Number of labelled points per class id, in class-id order.
class_counts = np.array(
    [np.sum(point_labels == class_id) for class_id in range(len(CLASS_NAMES))]
)

# Ignored labels must not contribute to the weighting.
class_counts[IGNORED_LABELS] = 0

# Relative frequency of every (non-ignored) class; the weights sum to 1.
class_weights = class_counts / class_counts.sum()

In [None]:
# Add a leading batch axis (batch size 1): point colours are the input
# features, point positions are the geometry.
features = torch.FloatTensor(surface_colors).unsqueeze(0).to(DEVICE)
# NOTE(review): xyz is created as a CPU tensor and is not moved to DEVICE here;
# presumably `prepare_input` / `SQN` handle placement themselves — confirm
# before switching DEVICE to a GPU.
xyz = torch.FloatTensor(surface_points).unsqueeze(0)

input_points, input_neighbors, input_pools = prepare_input(
    xyz, k=16, num_layers=3, sub_sampling_ratio=4, device=DEVICE
)

sqn = SQN(
    d_feature=3,
    d_in=8,
    encoder_dims=[8, 32, 64],
    decoder_dims=[64, len(CLASS_NAMES)],
    device=DEVICE,
)
# map_location remaps the stored tensors onto DEVICE while deserializing, so a
# checkpoint that was saved from a GPU still loads on a CPU-only machine.
sqn.load_state_dict(torch.load('../../data/sqn-0.0001', map_location=DEVICE))
# Switch to inference mode; the trailing ';' suppresses the module repr output.
sqn.eval();

In [None]:
# Inference only: no_grad() skips autograd bookkeeping for the forward pass.
with torch.no_grad():
    all_logits = sqn.forward(features, input_points, input_neighbors, input_pools, xyz)
    # Drop the singleton axes and convert to NumPy. `.cpu()` is a no-op for
    # CPU tensors but is required before `.numpy()` when DEVICE is a GPU.
    all_logits = all_logits.squeeze().detach().cpu().numpy()

In [None]:
# Per-class evaluation table: class key, IoU, class weight and their product.
df = pd.DataFrame({
    'class': list(CLASS_NAMES.keys()),
    'iou': get_iou(all_logits, point_labels),
    'weight': class_weights,
})
df['iou_weighted'] = df['iou'] * df['weight']

# Note: the printed value is the sum of the weighted per-class IoUs, i.e. a
# frequency-weighted IoU rather than an unweighted mean over classes.
print("mIOU:", df['iou_weighted'].sum())

In [None]:
# Display the per-class IoU / weight table built in the previous cell.
df

In [None]:
# Colour every surface point by its predicted class and show it in Open3D.
predicted_labels = np.argmax(all_logits, axis=-1)
label_colors = scene.colorize_labels(predicted_labels)

pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(surface_points)
pcd.colors = o3d.utility.Vector3dVector(label_colors)

# Opens the Open3D viewer window for the coloured point cloud.
o3d.visualization.draw_geometries([pcd])