In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
sys.path.insert(1, '/home/cem/Documents/imps/src')

import numpy as np
import open3d as o3d
import torch
from matplotlib import cm
import matplotlib.pyplot as plt
import torch
import torch.optim as optim

from imps.data import ScanNetScene, CLASS_NAMES
from imps.sqn.model import SQN
from imps.sqn.data_utils import prepare_input
from imps.ndf.local_model import MiniNDF

# --- Configuration -----------------------------------------------------------
# ScanNet scene to load. The default is the original location; set the
# SCANNET_SCENE_DIR environment variable to use another scene directory
# without editing the notebook.
SCENE_DIR = os.environ.get(
    'SCANNET_SCENE_DIR',
    '/home/cem/Documents/datasets/ScanNet/scans/scene0000_00',
)

# Arguments passed to scene.create_if_data(RESOLUTION, N_POINTS, SIGMAS).
N_POINTS = int(1.5e5)
RESOLUTION = 125
SIGMAS = np.array([0.5, 0.1, 0.01])

# Fraction of the N_POINTS indices that are drawn as labelled query points.
LABEL_RATIO = 0.001
N_LABEL = int(N_POINTS*LABEL_RATIO)

# Prefer the GPU, but fall back to the CPU so that the `.to(DEVICE)` calls
# further down do not crash on a machine without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Class ids whose counts are zeroed before the class weights are computed.
IGNORED_LABELS = [0]

def get_loss(logits, labels, class_weights):
    """Class-weighted cross-entropy, averaged over every point.

    Args:
        logits: Tensor of class scores whose last dimension equals the number
            of classes; it is flattened to ``(-1, n_classes)``.
        labels: Integer tensor of class indices, one per row of the flattened
            logits; it is flattened to ``(-1,)``.
        class_weights: One weight per class. May be a NumPy array, a tensor or
            a plain sequence of numbers.

    Returns:
        Scalar tensor holding the mean of ``weight[label] * cross_entropy``
        over all points.
    """
    # as_tensor accepts arrays, tensors and sequences alike, and converts to
    # float32 on the logits' device in one step.
    weights = torch.as_tensor(class_weights, dtype=torch.float32, device=logits.device)
    flat_logits = logits.reshape(-1, weights.shape[0])
    flat_labels = labels.reshape(-1)

    # reduction='none' followed by a plain mean divides by the number of
    # points, not by the sum of the weights (which reduction='mean' would do),
    # so points of zero-weight classes add no loss but still count in the
    # denominator.
    criterion = torch.nn.CrossEntropyLoss(weight=weights, reduction='none')
    return criterion(flat_logits, flat_labels).mean()

In [2]:
scene = ScanNetScene(SCENE_DIR)

(voxel_grid, surface_points, surface_colors,
 vicinity_points, vicinity_distances, point_labels) = scene.create_if_data(
    RESOLUTION, N_POINTS, SIGMAS
)

# Number of points carrying each class id, for ids 0 .. len(CLASS_NAMES) - 1.
class_counts = np.array([np.sum(point_labels == c) for c in range(len(CLASS_NAMES))])

# Ignored classes get a zero count, hence a zero weight in the loss.
for ign in IGNORED_LABELS:
    class_counts[ign] = 0
# Weights are the remaining class frequencies, normalised to sum to 1.
class_weights = class_counts / class_counts.sum()

# Draw the labelled subset from a seeded generator so that re-running this
# cell selects the same N_LABEL indices instead of a fresh random subset.
QUERY_SEED = 0
query_rng = np.random.default_rng(QUERY_SEED)
query_idxs = query_rng.choice(N_POINTS, N_LABEL, replace=False)
query_points = surface_points[query_idxs]
query_labels = point_labels[query_idxs]

In [3]:
# The NDF model is built with is_cuda=False, so load its weights onto the CPU
# explicitly. Without map_location, a checkpoint saved from a GPU cannot be
# deserialized on a machine that has no CUDA device.
ndf_model = MiniNDF(is_cuda=False)
ndf_model.load_state_dict(torch.load('../../data/ndf', map_location='cpu'))
ndf_model.eval();

# Add a leading batch dimension of size 1 to every input.
voxel_input = torch.FloatTensor(voxel_grid).unsqueeze(0)
xyz = torch.FloatTensor(surface_points).unsqueeze(0)
features = torch.FloatTensor(surface_colors).unsqueeze(0)

# Feature extraction only: no gradients are tracked for the NDF model.
with torch.no_grad():
    encoding = ndf_model.encoder(voxel_input)
    # Swap the last two axes of the extracted features.
    # NOTE(review): presumably this yields (batch, points, channels) to match
    # `features` for the later concatenation on dim=-1 — confirm against
    # MiniNDF.get_features.
    ndf_features = ndf_model.get_features(xyz, *encoding).permute(0,2,1).detach()

In [4]:
# Build every tensor from the arrays produced by the earlier cells rather than
# from the previous values of `features` / `query_labels`. This cell can then
# be re-run safely: it no longer concatenates the NDF features a second time
# or tries to wrap a tensor that already lives on DEVICE.
color_features = torch.FloatTensor(surface_colors).unsqueeze(0)
features = torch.cat([ndf_features, color_features], dim=-1).to(DEVICE)
query = torch.FloatTensor(query_points).unsqueeze(0).to(DEVICE)
query_labels = torch.LongTensor(point_labels[query_idxs]).unsqueeze(0).to(DEVICE)

input_points, input_neighbors, input_pools = prepare_input(
    xyz, k=16, num_layers=3, sub_sampling_ratio=4, device=DEVICE
)

# NOTE(review): d_feature=797 presumably equals features.shape[-1] (NDF
# channels plus colour channels) — confirm against the SQN definition.
sqn = SQN(d_feature=797, d_in=16, encoder_dims=[8, 32, 64], decoder_dims=[64, len(CLASS_NAMES)], device=DEVICE)
optimizer = optim.Adam(sqn.parameters(), lr=1e-3)

In [None]:
# Number of optimisation steps. Every step uses the full scene and all
# labelled query points.
N_EPOCHS = 50

sqn.train()

for epoch in range(N_EPOCHS):
    optimizer.zero_grad()

    # Call the module itself rather than .forward() so that any registered
    # forward hooks run as well.
    logits = sqn(features, input_points, input_neighbors, input_pools, query)
    loss = get_loss(logits, query_labels, class_weights)
    loss.backward()

    optimizer.step()

    print(f"Epoch {epoch+1}: {round(loss.item(), 4)}")

Epoch 1: 0.3331
Epoch 2: 0.2521
Epoch 3: 0.1972
Epoch 4: 0.159
Epoch 5: 0.132
Epoch 6: 0.1114
Epoch 7: 0.0953
Epoch 8: 0.0832
Epoch 9: 0.0732
Epoch 10: 0.065
Epoch 11: 0.0582
Epoch 12: 0.0525
Epoch 13: 0.0475
Epoch 14: 0.0432
Epoch 15: 0.0395
Epoch 16: 0.0363
Epoch 17: 0.0332
Epoch 18: 0.0306
Epoch 19: 0.0282
Epoch 20: 0.026
Epoch 21: 0.0241
Epoch 22: 0.0223
Epoch 23: 0.0206
Epoch 24: 0.0191
Epoch 25: 0.0178


In [None]:
# Write the trained SQN weights (state dict only) to disk.
sqn_checkpoint_path = '../../data/sqn-ndf'
sqn_weights = sqn.state_dict()
torch.save(sqn_weights, sqn_checkpoint_path)