In [5]:
%load_ext autoreload
%autoreload 2

import os
import sys
sys.path.insert(1, '/home/cem/Documents/imps/src')

import numpy as np
import open3d as o3d
import torch
from matplotlib import cm
import matplotlib.pyplot as plt
import torch
import torch.optim as optim

from imps.data import ScanNetScene, CLASS_NAMES
from imps.sqn.model import SQN
from imps.sqn.data_utils import prepare_input

# NOTE: machine-specific absolute path; point it at a local ScanNet scene directory.
SCENE_DIR = '/home/cem/Documents/datasets/ScanNet/scans/scene0000_00'
# Point budget handed to scene.create_if_data; also the population size used
# when drawing the labelled query indices.
N_POINTS = int(1.5e5)
# Not important here: we are not using this yet. Keep it small so data
# processing stays quick.
RESOLUTION = 25
SIGMAS = None
# Fraction of the N_POINTS points that are drawn as labelled query points.
LABEL_RATIO = 0.0001
N_LABEL = int(N_POINTS*LABEL_RATIO)
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing on the first `.to(DEVICE)`.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Class ids whose count (and therefore class weight) is zeroed before training.
IGNORED_LABELS = [0]

def get_loss(logits, labels, class_weights):
    """Compute a class-weighted cross-entropy loss averaged over all samples.

    Args:
        logits: Tensor of class scores. It is flattened to
            ``(-1, num_classes)``, where ``num_classes`` is the length of
            ``class_weights``.
        labels: Integer tensor of target class indices; flattened to 1-D.
        class_weights: 1-D per-class weights. May be a numpy array, a Python
            sequence, or a tensor; it is converted to a float32 tensor on the
            same device as ``logits``.

    Returns:
        A scalar tensor: the plain mean of the per-sample weighted losses.
        Samples whose class has weight 0 contribute 0 to the sum but are
        still counted in the denominator (this is not the weight-normalised
        mean that ``reduction='mean'`` would give).
    """
    # as_tensor accepts arrays, sequences and tensors alike, whereas
    # from_numpy only accepts numpy arrays.
    class_weights = torch.as_tensor(class_weights, dtype=torch.float32, device=logits.device)
    num_classes = class_weights.shape[0]

    logits = logits.reshape(-1, num_classes)
    labels = labels.reshape(-1)

    # Keep per-sample losses and average them ourselves (see Returns above).
    criterion = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='none')
    return criterion(logits, labels).mean()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [21]:
# Inspect the shape of the scene mesh's per-face colour array.
# NOTE: `scene` is only created in a later cell (scene = ScanNetScene(SCENE_DIR)),
# so this cell fails on a fresh kernel unless that cell has been run first.
scene.mesh.visual.face_colors.shape

(153587, 4)

In [20]:
# Inspect the shape of the Open3D mesh's vertex array.
# NOTE: `mesh` is only built in a later cell (the o3d.geometry.TriangleMesh cell),
# so this cell fails on a fresh kernel unless that cell has been run first.
np.array(mesh.vertices).shape

(81369, 3)

In [None]:
# Rebuild the scene mesh as an Open3D triangle mesh and open it in a viewer.
mesh_vertices = o3d.utility.Vector3dVector(scene.mesh.vertices)
mesh_triangles = o3d.utility.Vector3iVector(scene.mesh.faces)
mesh = o3d.geometry.TriangleMesh(vertices=mesh_vertices, triangles=mesh_triangles)

# Drop the last colour column (presumably alpha -- confirm) and rescale by 255.
vertex_rgb = scene.mesh.visual.vertex_colors[:, :-1] / 255
mesh.vertex_colors = o3d.utility.Vector3dVector(vertex_rgb)

o3d.visualization.draw_geometries([mesh])

In [None]:
scene = ScanNetScene(SCENE_DIR)

# Only surface_points, surface_colors and point_labels are used further down
# in this notebook; the other returned values are kept for inspection.
(voxel_grid,
 surface_points,
 surface_colors,
 vicinity_points,
 vicinity_distances,
 point_labels) = scene.create_if_data(RESOLUTION, N_POINTS, SIGMAS)

# Number of points per class id, for ids 0 .. len(CLASS_NAMES) - 1.
class_counts = np.array(
    [np.sum(point_labels == class_id) for class_id in range(len(CLASS_NAMES))]
)

# Zero the counts of ignored classes so that their loss weight becomes 0.
class_counts[IGNORED_LABELS] = 0
# NOTE(review): weights are proportional to class frequency, so frequent
# classes weigh more in the loss -- confirm this is intended rather than an
# inverse-frequency weighting.
class_weights = class_counts / class_counts.sum()

# Draw N_LABEL distinct point indices; these are the only labelled points.
query_idxs = np.random.choice(N_POINTS, N_LABEL, replace=False)
query_points, query_labels = surface_points[query_idxs], point_labels[query_idxs]

In [None]:
# Every tensor gets a leading batch dimension of size 1.
features = torch.FloatTensor(surface_colors).unsqueeze(0).to(DEVICE)
# xyz is not moved to DEVICE here; prepare_input receives the device separately.
xyz = torch.FloatTensor(surface_points).unsqueeze(0)
query = torch.FloatTensor(query_points).unsqueeze(0).to(DEVICE)

# `query_labels` is rebound from a numpy array to a batched tensor. Guard the
# conversion so that re-running this cell does not push the already-converted
# tensor through LongTensor(...).unsqueeze(0) a second time.
if not isinstance(query_labels, torch.Tensor):
    query_labels = torch.LongTensor(query_labels).unsqueeze(0)
query_labels = query_labels.to(DEVICE)

input_points, input_neighbors, input_pools = prepare_input(
    xyz, k=16, num_layers=3, sub_sampling_ratio=4, device=DEVICE
)

# One output logit per entry in CLASS_NAMES.
sqn = SQN(d_feature=3, d_in=8, encoder_dims=[8, 32, 64], decoder_dims=[64, len(CLASS_NAMES)], device=DEVICE)
optimizer = optim.Adam(sqn.parameters(), lr=1e-3)

In [None]:
sqn.train()

# Number of optimisation steps; each step uses the whole scene and every
# labelled query point at once.
N_EPOCHS = 50

for e in range(N_EPOCHS):
    optimizer.zero_grad()

    # Call the module itself rather than .forward() so that any registered
    # hooks run as part of the forward pass.
    logits = sqn(features, input_points, input_neighbors, input_pools, query)
    loss = get_loss(logits, query_labels, class_weights)
    loss.backward()

    optimizer.step()

    print(f"Epoch {e+1}: {round(loss.item(), 4)}")

In [None]:
# The checkpoint name carries the label ratio used for training. Deriving it
# from LABEL_RATIO keeps the name in sync with the config; for the current
# value (0.0001) this is the same path as before: '../../data/sqn-0.0001'.
torch.save(sqn.state_dict(), f'../../data/sqn-{LABEL_RATIO}')