In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
#sys.path.insert(1, '/home/samanti/Desktop/THESIS_LMT/imps/src')

import numpy as np
import open3d as o3d
import torch
from matplotlib import cm
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import jaccard_score

from imps.data.scannet import ScanNetScene, CLASS_NAMES

from imps.sqn.model import Randla
from imps.sqn.data_utils import prepare_input
from imps.point_augment.Common import loss_utils

from imps.point_augment.Augmentor.augmentor import Augmentor

SCENE_DIR = '/app/mnt/scans/scene0000_00'

N_POINTS = int(1.5e5)
DEVICE = 'cpu'

IGNORED_LABELS = [0]

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
scene = ScanNetScene(SCENE_DIR)

surface_points, surface_colors, point_labels = scene.create_points_colors_labels_from_pc(N_POINTS)

In [3]:
class_counts = []
for c in range(len(CLASS_NAMES.keys())):
    class_counts.append(np.sum(point_labels == c))
class_counts = np.array(class_counts)

for ign in IGNORED_LABELS:
    class_counts[ign] = 0
class_weights = class_counts / class_counts.sum()

In [4]:
def get_iou(logits, labels):
    preds = np.argmax(logits, axis=-1)
    ious = []
    
    for c in range(len(CLASS_NAMES)):
        iou = jaccard_score((labels==c).astype(int), (preds==c).astype(int), pos_label=1)
        ious.append(iou)
        
    return np.array(ious)

In [5]:
features = torch.FloatTensor(surface_colors).unsqueeze(0).to(DEVICE)
xyz = torch.FloatTensor(surface_points).unsqueeze(0)
point_labels = torch.LongTensor(point_labels).unsqueeze(0).to(DEVICE)

input_points, input_neighbors, input_pools, feat_shape = prepare_input(xyz, k=16, num_layers=3, sub_sampling_ratio=4, 
                                                           device=DEVICE)

randla = Randla(d_feature=3, d_in=8, encoder_dims=[8, 32, 64], device=DEVICE, num_class=len(CLASS_NAMES), interpolator='keops' )
randla.load_state_dict(torch.load('../../processed/saved_models/PA_baseline_scene0040_00'))
randla.eval()

Randla(
  (fc0): Conv1d(
    (conv): Conv1d(3, 8, kernel_size=(1,), stride=(1,), bias=False)
    (bn): BatchNorm1d(
      (bn): BatchNorm1d(8, eps=1e-06, momentum=0.99, affine=True, track_running_stats=True)
    )
    (activation): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (dilated_res_blocks): ModuleList(
    (0): Dilated_res_block(
      (mlp1): Conv2d(
        (conv): Conv2d(8, 4, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(
          (bn): BatchNorm2d(4, eps=1e-06, momentum=0.99, affine=True, track_running_stats=True)
        )
        (activation): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (lfa): Building_block(
        (mlp1): Conv2d(
          (conv): Conv2d(10, 4, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(
            (bn): BatchNorm2d(4, eps=1e-06, momentum=0.99, affine=True, track_running_stats=True)
          )
          (activation): LeakyReLU(negative_slope=0.2, inplace=True)
        )
  

In [6]:
with torch.no_grad():
    all_logits = randla.forward(features, input_points, input_neighbors, input_pools)
    all_logits = all_logits.squeeze().detach().numpy()

In [7]:
df = pd.DataFrame()
df['class'] = CLASS_NAMES.keys()
df['iou'] = get_iou(all_logits, point_labels.detach().squeeze().numpy())
df['weight'] = class_weights
df['iou_weighted'] = df['iou'] * df['weight']

print("mIOU:", df.iou_weighted.sum())

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


mIOU: 0.1464491298728035


In [8]:
df

Unnamed: 0,class,iou,weight,iou_weighted
0,unannotated,0.000401,0.0,0.0
1,wall,0.242413,0.24145,0.058531
2,floor,0.324117,0.257958,0.083608
3,cabinet,0.0,0.131839,0.0
4,bed,0.0,0.053305,0.0
5,chair,0.0,0.0,0.0
6,sofa,0.0,0.08959,0.0
7,table,0.078547,0.054814,0.004305
8,door,0.0,0.011451,0.0
9,window,0.000496,0.009296,5e-06


In [1]:
# 7.7% improvement noticed without Point Augment