In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
sys.path.insert(1, '/home/cem/Documents/imps/src')

import numpy as np
import open3d as o3d
import torch
from matplotlib import cm
import matplotlib.pyplot as plt
from tqdm import tqdm
import pickle

from imps.data import ScanNetScene, CLASS_NAMES
from imps.ndf.local_model import MiniNDF, NDF

# Root of the ScanNet `scans` directory. Set SCANNET_SCAN_DIR to run on another machine;
# the default keeps the original author's path.
SCAN_DIR = os.environ.get('SCANNET_SCAN_DIR', '/home/cem/Documents/datasets/ScanNet/scans')
SCENE_NAMES = ['scene0000_00']

N_POINTS = 20_000       # point count passed to ScanNetScene.create_if_data
RESOLUTION = 125        # grid resolution passed to ScanNetScene.create_if_data
# NOTE(review): presumably the sampling noise scales around the surface — confirm against imps.data.
SIGMAS = np.array([0.09, 0.03, 0.01])
# Use the GPU when available, otherwise fall back to CPU so the notebook still runs.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Load one ScanNet scene and turn its sampled training data into batched float32 tensors.
# Paths and sizes come from the config cell (SCAN_DIR, SCENE_NAMES, RESOLUTION, N_POINTS, SIGMAS, DEVICE).
scene_name = SCENE_NAMES[0]
# NOTE(review): stands in for a previously undefined index `s`; 0 selects the first entry of the
# create_if_data outputs — confirm which entry (scene / chunk / sigma level) is intended.
sample_idx = 0

scene = ScanNetScene(os.path.join(SCAN_DIR, scene_name))
voxels, surface_points, surface_colors, vicinities, distances, _ = scene.create_if_data(
    RESOLUTION, N_POINTS, SIGMAS, o3d_format=False, scale_sigmas=False
)


def to_batched_tensor(array):
    """Return `array` as a float32 tensor on DEVICE with a leading batch dimension of size 1."""
    return torch.as_tensor(array, dtype=torch.float32, device=DEVICE).unsqueeze(0)


query_points = to_batched_tensor(vicinities[sample_idx])
query_distances = to_batched_tensor(distances[sample_idx])
voxel_grid = to_batched_tensor(voxels[sample_idx])

In [None]:
# Build the NDF model and its Adam optimizer on the configured device.
# NOTE(review): MiniNDF is imported but NDF is used, while the training checkpoint name
# mentions MiniNDF — confirm which architecture is intended.
LEARNING_RATE = 1e-4

model = NDF().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
model.train()

In [None]:
# Fit the model to the loaded scene with a truncated L1 distance loss,
# checkpointing model + optimizer state on the first epoch and every LOG_EVERY epochs.
N_EPOCHS = 1000
DIST_CLAMP = 0.25   # upper clamp applied to both predicted and target distances before the L1 loss
LOG_EVERY = 100
# NOTE(review): filename mentions MiniNDF / 11 scenes, but this cell trains `model` on one scene.
CHECKPOINT_PATH = '../../data/MiniNDF-11-scenes'

dist_criterion = torch.nn.L1Loss(reduction='none')  # built once, not per iteration
losses = []  # one scalar loss per epoch

model.train()  # re-assert train mode in case the eval cell was run before re-training
pbar = tqdm(range(N_EPOCHS))
for e in pbar:
    optimizer.zero_grad()

    # Call the module (not .forward) so registered hooks run.
    pred_dist = model(query_points, voxel_grid)
    dist_loss = dist_criterion(
        torch.clamp(pred_dist, max=DIST_CLAMP),
        torch.clamp(query_distances, max=DIST_CLAMP),
    ).sum(dim=-1).mean()

    dist_loss.backward()
    optimizer.step()

    losses.append(dist_loss.item())

    if (e + 1) % LOG_EVERY == 0 or e == 0:
        pbar.set_description(f"Epoch {e+1} Loss: {round(losses[-1], 3)}")
        torch.save({'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()},
                   CHECKPOINT_PATH)

In [None]:
# Inference pass: switch to eval mode and run the encoder/decoder without tracking gradients.
model.eval()
no_grad_ctx = torch.no_grad()
with no_grad_ctx:
    encoding = model.encoder(voxel_grid)
    decoder_args = (query_points, *encoding)  # encoder outputs are splatted after the query points
    pred_dist = model.decoder(*decoder_args)

In [None]:
# Show the query points as an Open3D point cloud colored by predicted distance.
pts = query_points.squeeze().detach().cpu().numpy()
dst = pred_dist.squeeze().detach().cpu().numpy()

# matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9; pyplot.get_cmap remains supported.
dist_cmap = plt.get_cmap('Reds')
vic_pcd = o3d.geometry.PointCloud()
vic_pcd.points = o3d.utility.Vector3dVector(pts)
# The colormap returns RGBA; drop the alpha channel. Values outside [0, 1] saturate at the ends.
vic_pcd.colors = o3d.utility.Vector3dVector(dist_cmap(dst)[:, :-1])

o3d.visualization.draw_geometries([vic_pcd])

In [None]:
# Persist only the trained weights (no optimizer state).
final_state = model.state_dict()
torch.save(final_state, '../../data/ndf-v2')