In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
from datetime import datetime

sys.path.insert(1, '/home/cem/Documents/imps/src')

import numpy as np
import matplotlib.pyplot as plt
import open3d as o3d
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import jaccard_score
from tqdm import tqdm

from imps.sqn.model import SQN
from imps.sqn.data_utils import prepare_input
from imps.metrics import compute_iou

# Fall back to CPU so the notebook does not crash outright on machines without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Paths may be overridden through environment variables; the defaults are the
# original machine-specific locations.
DATA_ROOT = os.environ.get(
    'IMPS_DATA_ROOT',
    '/mnt/data.nas/shareddata/6G-futurelab/synthetic_room_dataset/rooms_04',
)
scene = '00000004'     # scene id: name of the scene directory under DATA_ROOT
POS_EMBEDDING = True   # use pos_embed(xyz) as per-point features instead of raw xyz
LOGDIR = os.environ.get('IMPS_LOGDIR', './logs')

scene_dir = os.path.join(DATA_ROOT, scene)

def get_loss(logits, labels, pos_weight):
    """Mean class-weighted binary cross-entropy between ``logits`` and ``labels``.

    Both tensors are flattened to ``(batch, -1)``, using the leading dimension
    of ``logits`` as the batch size. For example, logits of shape (B, N, 1) can
    be compared against labels of shape (B, N).

    Args:
        logits: raw (pre-sigmoid) scores.
        labels: binary targets with the same number of elements as ``logits``.
        pos_weight: weight of the positive class, forwarded to
            ``torch.nn.BCEWithLogitsLoss``.

    Returns:
        A scalar tensor: the element-wise loss averaged over all elements.
    """
    batch_size = logits.shape[0]
    flat_logits, flat_labels = (t.reshape(batch_size, -1) for t in (logits, labels))

    bce = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='none')
    return bce(flat_logits, flat_labels).mean()

def pos_embed(pos, L=10):
    """Sinusoidal (NeRF-style) positional encoding of coordinates.

    For each frequency index ``l`` in ``range(L)``, the encoding appends
    ``sin(2**l * pi * pos)`` and ``cos(2**l * pi * pos)``. The pieces are
    concatenated along the last axis in the order
    ``[sin_0, cos_0, sin_1, cos_1, ...]``.

    Args:
        pos: array-like with at least one dimension; the last axis holds the
            coordinate components (e.g. an (N, 3) point array).
        L: number of frequency bands; must be >= 1.

    Returns:
        A float64 array with the same leading shape as ``pos`` and a last axis
        of size ``2 * L * pos.shape[-1]``. That is 60 for 3-D points with the
        default ``L=10``.

    Raises:
        ValueError: if ``L < 1`` (there would be nothing to concatenate).
    """
    if L < 1:
        raise ValueError(f'L must be >= 1, got {L}')

    # Work in float64: the highest band scales pos by 2**(L-1) * pi, where
    # float32 inputs would lose noticeable precision.
    pos = np.asarray(pos, dtype=np.float64)

    embs = []
    for l in range(L):
        # `**` is exponentiation; `^` would be bitwise XOR in Python and yield
        # the frequencies 2, 3, 0, 1, 6, 7, ... instead of 1, 2, 4, 8, ...
        angle = (2.0 ** l) * np.pi * pos
        embs += [np.sin(angle), np.cos(angle)]

    # Already float64 (np.float, the old alias of builtin float, was removed in NumPy 1.24).
    return np.concatenate(embs, axis=-1)

def get_data(scene_points, query_points, query_occ, embedding=False, seed=None):
    """Build the SQN input tensors for one scene.

    Args:
        scene_points: input point cloud, converted with ``torch.FloatTensor``
            (presumably an (N, 3) array — TODO confirm).
        query_points: locations at which occupancy is queried.
        query_occ: occupancy labels for ``query_points``.
        embedding: if True, the per-point features are ``pos_embed(scene_points)``;
            otherwise the raw coordinates are used as features.
        seed: controls torch's *global* RNG before the point permutation.
            ``None`` uses its current state, ``-1`` seeds it with 31, and any
            other value seeds it with that value.

    Returns:
        Tuple ``(features, input_points, input_neighbors, input_pools, query,
        query_labels)``. ``features``, ``query`` and ``query_labels`` carry a
        leading batch dimension of 1 and are moved to ``DEVICE``. The other
        three come straight from ``prepare_input``.
    """
    point_feats = pos_embed(scene_points) if embedding else scene_points
    features = torch.FloatTensor(point_feats).unsqueeze(0).to(DEVICE)
    xyz = torch.FloatTensor(scene_points).unsqueeze(0)

    # The input cloud is ordered by object, so it is shuffled before sub-sampling.
    # NOTE: results are NOT invariant to this permutation. Working hypothesis:
    # permuting changes which points survive sub-sampling, and therefore which
    # features are picked up during kNN up-sampling.
    if seed is not None:
        torch.manual_seed(31 if seed == -1 else seed)

    perm = torch.randperm(xyz.shape[1])
    xyz, features = xyz[:, perm], features[:, perm]

    query = torch.FloatTensor(query_points).unsqueeze(0).to(DEVICE)
    query_labels = torch.FloatTensor(query_occ).unsqueeze(0).to(DEVICE)

    input_points, input_neighbors, input_pools = prepare_input(
        xyz, k=8, num_layers=3, sub_sampling_ratio=4, device=DEVICE)

    return features, input_points, input_neighbors, input_pools, query, query_labels

In [None]:
exp_name = 'occ-emb-skip-seed=31-lr=sch'

# Suffix the run name with a Unix timestamp so every run gets its own log directory.
now = datetime.now()
exp_name += f'-{int(now.timestamp())}'
exp_dir = os.path.join(LOGDIR, exp_name)
# Checkpoints are written into exp_dir; create it explicitly rather than relying
# on SummaryWriter to do so as a side effect.
os.makedirs(exp_dir, exist_ok=True)

scene_pcd = o3d.io.read_point_cloud(os.path.join(scene_dir, 'pointcloud0.ply'))

# NOTE: `scene` is rebound here from the scene-id string to the loaded npz archive.
scene = np.load(os.path.join(scene_dir, 'pointcloud', 'pointcloud_00.npz'))
query_iou = np.load(os.path.join(scene_dir, 'points_iou', 'points_iou_00.npz'))

scene_points = scene['points']
query_points = query_iou['points']
# Occupancies are stored bit-packed. np.unpackbits pads the result to a multiple
# of 8, so trim it to exactly one label per query point.
query_occ = np.unpackbits(query_iou['occupancies'])[:len(query_points)]
query_semantics = query_iou['semantics']

# Weight the positive (occupied) class by the free/occupied ratio to counter class imbalance.
n_occupied = int(np.sum(query_occ == 1))
n_free = int(np.sum(query_occ == 0))
if n_occupied == 0:
    raise ValueError(f'No occupied query points in {scene_dir}; cannot compute pos_weight.')
pos_w = torch.FloatTensor([n_free / n_occupied]).to(DEVICE)

query_pcd = o3d.geometry.PointCloud()
query_pcd.points = o3d.utility.Vector3dVector(query_points[query_occ == 1])

# d_feature is assumed to be the width of the per-point features built by get_data:
# 2 * L * dim = 60 for 3-D points with pos_embed's default L=10, or the raw
# coordinate width without embedding — TODO confirm against SQN.
d_feature = 2 * 10 * scene_points.shape[-1] if POS_EMBEDDING else scene_points.shape[-1]

sqn = SQN(d_feature=d_feature, d_in=64, encoder_dims=[32, 64, 128], decoder_dims=[128, 32, 1],
          device=DEVICE, skip_connections=True)
optimizer = optim.Adam(sqn.parameters(), lr=1e-3)
# Halve the learning rate every 250 scheduler steps.
scheduler = StepLR(optimizer, step_size=250, gamma=0.5)
writer = SummaryWriter(exp_dir)

# seed=31 fixes the point permutation inside get_data.
features, input_points, input_neighbors, input_pools, query, query_labels = get_data(
    scene_points, query_points, query_occ, embedding=POS_EMBEDDING, seed=31)

In [None]:
writer.add_mesh('pc', input_points[0])

N_EPOCHS = 10000
MIN_LR = 2e-5         # stop decaying the learning rate once it falls to this floor
MESH_LOG_EVERY = 100  # log the predicted occupied points every N epochs
CKPT_EVERY = 100      # write a checkpoint every N epochs
ckpt_path = os.path.join(exp_dir, 'model')

try:
    for e in tqdm(range(N_EPOCHS)):
        sqn.train()
        optimizer.zero_grad()

        # Call the module rather than .forward() so nn.Module hooks run.
        logits = sqn(features, input_points, input_neighbors, input_pools, query)
        loss = get_loss(logits, query_labels, pos_w)
        loss.backward()
        optimizer.step()

        writer.add_scalar('loss', loss.item(), e)

        with torch.no_grad():
            sqn.eval()
            logits = sqn(features, input_points, input_neighbors, input_pools, query)

            # The network outputs logits (it is trained with BCEWithLogitsLoss), so
            # the p = 0.5 decision boundary is logit = 0, not 0.5.
            # Builtin int replaces np.int, which was removed in NumPy 1.24.
            pred = (logits.cpu().numpy().squeeze() > 0).astype(int)
            gold = query_labels.cpu().numpy().squeeze()

            occ_iou = jaccard_score(gold, pred)
            writer.add_scalar('occ-iou', occ_iou, e)

            if e % MESH_LOG_EVERY == 0:
                writer.add_mesh('occ-pc', query[:, pred == 1], global_step=e)

        if scheduler.get_last_lr()[-1] > MIN_LR:
            scheduler.step()
        writer.add_scalar('lr', scheduler.get_last_lr()[-1], e)

        # Checkpoint periodically instead of on every single iteration.
        if (e + 1) % CKPT_EVERY == 0:
            torch.save(sqn.state_dict(), ckpt_path)
finally:
    # Always leave the latest weights on disk, including when the loop is interrupted.
    torch.save(sqn.state_dict(), ckpt_path)
    writer.flush()

In [None]:
# Reload the weights that the training loop last saved to <exp_dir>/model.
sqn.load_state_dict(
    torch.load(os.path.join(exp_dir, 'model'))
);

In [None]:
# seed=31 reproduces the point permutation used for training. Predictions are not
# invariant to that permutation, so a different seed changes the result.
features, input_points, input_neighbors, input_pools, query, query_labels = get_data(
    scene_points, query_points, query_occ, embedding=POS_EMBEDDING, seed=31)

with torch.no_grad():
    sqn.eval()
    logits = sqn(features, input_points, input_neighbors, input_pools, query)

    # Logits -> occupancy: sigmoid(logit) > 0.5  <=>  logit > 0.
    # Builtin int replaces np.int, which was removed in NumPy 1.24.
    pred = (logits.cpu().numpy().squeeze() > 0).astype(int)
    gold = query_labels.cpu().numpy().squeeze()

    print(jaccard_score(gold, pred))

In [None]:
# Keep only the query points predicted as occupied (pred == 1), report how many
# there are, and render them with Open3D.
occ_pts = query.squeeze().cpu().numpy()[pred == 1]
print(occ_pts.shape[0])

occ_pcd = o3d.geometry.PointCloud()
occ_pcd.points = o3d.utility.Vector3dVector(occ_pts)
o3d.visualization.draw_geometries([occ_pcd])

In [None]:
np.shape(query_points)  # sanity check: shape of the loaded query-point array

In [None]:
jaccard_score(y_true=gold, y_pred=pred)  # occupancy IoU of the latest predictions

In [None]:
np.shape(gold)  # sanity check: one ground-truth label per query point