In [None]:
import matplotlib
import matplotlib.pyplot as plt
import math
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torch.autograd import Function
from apex import amp
from typing import Optional, Tuple, List, Union, Callable
from grid_fusion_pytorch.dataset import PointCloudDataset, CustomCollate
import time
import sys
import os
import shutil
from tqdm.auto import tqdm
from pprint import pprint
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.losses import FocalLoss

import torch_scatter

import sys

from grid_fusion_pytorch.render import render_grids, batch_fuse_to_grid, get_chunks
#from pt_svr.render import render_grids, batch_fuse_to_grid, get_chunks
from grid_fusion_pytorch.model import RefineModel

import warnings 
warnings.filterwarnings("ignore", category=UserWarning)  

In [None]:
#from https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/LovaszSoftmax/lovasz_loss.py
def lovasz_grad(gt_sorted):
    """
    Gradient of the Lovasz extension with respect to sorted errors
    (Alg. 1 of the Lovasz-Softmax paper).

    Args:
        gt_sorted: 1-D tensor of 0/1 ground-truth indicators, ordered by
            descending prediction error.

    Returns:
        1-D float tensor with the same length as ``gt_sorted``.
    """
    n_elements = len(gt_sorted)
    total_positive = gt_sorted.sum()
    # Running counts of positives / negatives seen so far in the sorted order.
    positives_seen = gt_sorted.float().cumsum(0)
    negatives_seen = (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - (total_positive - positives_seen) / (total_positive + negatives_seen)
    if n_elements <= 1:
        # a single element has no predecessor to difference against
        return jaccard
    # Turn the cumulative Jaccard values into per-element increments.
    jaccard[1:n_elements] = jaccard[1:n_elements] - jaccard[0:-1]
    return jaccard

# adapted from https://github.com/RaduAlexandru/lattice_net/blob/5255be27706e91b23f6e13c68c8097178cda1c34/latticenet_py/lattice/lovasz_loss.py#L23
class LovaszSoftmax(nn.Module):
    """
    Multi-class Lovasz-Softmax loss over flattened predictions.

    Args:
        reduction: ``'none'`` returns one loss per class that occurs in the
            targets, ``'sum'`` adds them up, any other value averages them.
    """

    def __init__(self, reduction='mean'):
        super(LovaszSoftmax, self).__init__()
        self.reduction = reduction

    def lovasz_softmax_flat(self, inputs, targets):
        """
        Compute the loss from class probabilities.

        Args:
            inputs: (N, C) tensor of class probabilities.
            targets: (N,) tensor of class indices; entries outside ``0..C-1``
                match no class and therefore contribute to no class loss.

        Returns:
            The reduced loss (scalar), or a 1-D tensor of per-class losses
            when ``reduction == 'none'``. If no class is present in
            ``targets`` the result is zero (an empty tensor for ``'none'``)
            instead of raising from ``torch.stack`` on an empty list.
        """
        num_classes = inputs.shape[1]
        losses = []
        for c in range(num_classes):
            target_c = (targets == c).float()
            if target_c.sum() == 0:
                # as described in the paper, classes absent from this sample are skipped
                continue
            input_c = inputs[:, c]
            loss_c = (target_c - input_c).abs()
            loss_c_sorted, loss_index = torch.sort(loss_c, 0, descending=True)
            target_c_sorted = target_c[loss_index]
            losses.append(torch.dot(loss_c_sorted, lovasz_grad(target_c_sorted)))

        if not losses:
            # Nothing to penalise (empty batch or no valid label).
            if self.reduction == 'none':
                return inputs.new_zeros(0)
            # Zero that is still attached to the autograd graph of `inputs`,
            # so a subsequent backward() works.
            return inputs.sum() * 0.

        losses = torch.stack(losses)
        if self.reduction == 'none':
            return losses
        if self.reduction == 'sum':
            return losses.sum()
        return losses.mean()

    def forward(self, inputs, targets):
        """Apply softmax over dim 1 of the logits ``inputs`` and return the loss."""
        inputs = F.softmax(inputs, dim=1)
        return self.lovasz_softmax_flat(inputs, targets)

In [None]:
# Single source of truth for the 17 classes: (label, RGB in 0-255), in class-index order.
_PALETTE_17 = (
    ('cracker_box', (161, 203, 242)),
    ('sugar_box', (227, 88, 34)),
    ('mustard_bottle', (247, 167, 0)),
    ('potted_meat_can', (100, 68, 34)),
    ('banana', (243, 195, 0)),
    ('bleach_cleanser', (137, 44, 22)),
    ('mug', (219, 210, 0)),
    ('sponge', (179, 68, 108)),
    ('spatula', (41, 183, 0)),
    ('power_drill', (96, 78, 151)),
    ('wood_block', (249, 147, 120)),
    ('extra_large_clamp', (0, 103, 166)),
    ('softball', (230, 143, 172)),
    ('golf_ball', (0, 136, 85)),
    ('dice', (132, 132, 130)),
    ('toy_airplane', (195, 179, 129)),
    ('red_box', (191, 0, 50)),
)

# 17 x 3 tensor of RGB colours scaled to [0, 1]; row i is the colour of labels[i].
cmap_17 = torch.tensor([rgb for _, rgb in _PALETTE_17], dtype=torch.get_default_dtype()) / 255.

# Class names in the same order as the rows of cmap_17.
labels = [name for name, _ in _PALETTE_17]

In [None]:
# N x 3 ray origins and N x 3 ray_dirs
def show_rays(ray_origins, ray_dirs):
    """
    Display rays as arrows in a 3D quiver plot.

    Args:
        ray_origins: array whose last axis holds the (x, y, z) start of each ray.
        ray_dirs: array whose last axis holds the (x, y, z) direction of each ray.
    """
    fig = plt.figure(figsize=(24, 16))
    ax = fig.add_subplot(projection='3d')
    ax.view_init(elev=25., azim=45.)
    # Flatten each coordinate so arbitrarily shaped ray batches can be drawn.
    origin_xyz = [ray_origins[..., axis].flatten() for axis in range(3)]
    direction_xyz = [ray_dirs[..., axis].flatten() for axis in range(3)]
    ax.quiver(*origin_xyz, *direction_xyz, length=0.1, normalize=True, lw=0.05)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('z')
    plt.show()

# points: sequence of N_i x 3 point clouds, each drawn in its own colour
def show_points(points):
    """
    Scatter-plot several point clouds in one 3D figure.

    Args:
        points: sequence of arrays indexable as ``p[:, 0..2]`` (x, y, z).
            The clouds may have different numbers of points.
    """
    fig = plt.figure(figsize=(24, 16))
    ax = fig.add_subplot(projection='3d')
    ax.view_init(elev=25., azim=45.)
    for i, p in enumerate(points):
        # The colour array needs one entry per point of *this* cloud; sizing it
        # from points[0] breaks as soon as clouds differ in length.
        cloud_colour = i * np.ones(len(p))
        ax.scatter(p[:,0], p[:,1], p[:,2], s=0.1, c=cloud_colour, cmap='tab20', vmin=0, vmax=len(points))
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('z')
    plt.show()

In [None]:
# ---- dataset / dataloader configuration ----
MODE = 'normal'
DATASET_NAME = 'overfit_O39_S1_C32_pcd_' + MODE
NUM_CAMS = -1
BATCH_SIZE = 1
INSPECT_SHAPES = False  # set True to print the shapes produced by the dataset and the loader
RANDOM_PCD = False

# Dataset used throughout the notebook.
test_dataset = PointCloudDataset(
    root='/home/nfs/inf6/data/datasets/semantic_pcd_data/' + DATASET_NAME,
    split='full',
    num_steps=-1,
    num_cams=NUM_CAMS,
    random_pcd=RANDOM_PCD,
    cam_world=False,
)

if INSPECT_SHAPES:
    print('Length of test_dataset:', len(test_dataset), '\n')

    # one raw sample, straight from the dataset
    test_output = test_dataset[0]
    pcd, semseg, cam_pose, depth, cam_k, gt = test_output
    print('Dataset __getitem__ output len/shapes:\n', len(pcd), semseg.shape, cam_pose.shape, depth.shape, cam_k.shape, gt.shape, '\n')
    print('pcd element shapes:')
    for item in pcd:
        print(item.shape)
    print()

# Collate function and dataloader; the number of steps per sample is bounded by
# min_num_steps / max_num_steps and can be changed later through set_steps().
collate = CustomCollate(min_num_steps=1, max_num_steps=2)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate,
)

if INSPECT_SHAPES:
    # shapes of every minibatch with the initial step limits
    for pcd, semseg, cam_pose, depth, cam_k, gt in test_dataloader:
        print('Minibatch output len/shapes:\n', len(pcd), semseg.shape, cam_pose.shape, depth.shape, cam_k.shape, gt.shape)
        print('pcd batched element shapes:')
        for item in pcd:
            print(item.shape)
    print()

    # same inspection after raising the maximum number of fusion/refinement steps
    test_dataloader.collate_fn.set_steps(1, 8)
    for pcd, semseg, cam_pose, depth, cam_k, gt in test_dataloader:
        print('Minibatch output len/shapes:\n', len(pcd), semseg.shape, cam_pose.shape, depth.shape, cam_k.shape, gt.shape)
        print('pcd batched element shapes:')
        for item in pcd:
            print(item.shape)

In [None]:
### positional encoding taken from https://colab.research.google.com/drive/1TppdSsLz8uKoNwqJqDGg8se8BHQcvg_K?usp=sharing#scrollTo=rrbs7YoMHAbF
class PositionalEncoder(nn.Module):
    """
    Sine-cosine positional encoder for input points.

    The encoding concatenates the raw input with ``sin(x * f)`` and
    ``cos(x * f)`` for each of ``n_freqs`` frequencies ``f``.

    Args:
        d_input: size of the last axis of the input.
        n_freqs: number of frequency bands.
        log_space: if True use frequencies ``2**0 .. 2**(n_freqs - 1)``;
            otherwise space them linearly between those two end points.
    """

    def __init__(self, d_input: int, n_freqs: int, log_space: bool = False):
        super().__init__()
        self.d_input = d_input
        self.n_freqs = n_freqs
        self.log_space = log_space
        # identity term plus one (sin, cos) pair per frequency
        self.d_output = d_input * (1 + 2 * self.n_freqs)

        # Frequencies in either log or linear scale.
        if self.log_space:
            freq_bands = 2.**torch.linspace(0., self.n_freqs - 1, self.n_freqs)
        else:
            freq_bands = torch.linspace(2.**0., 2.**(self.n_freqs - 1), self.n_freqs)

        # Identity first, then alternating sin / cos per frequency.
        self.embed_fns = [lambda x: x]
        for freq in freq_bands:
            self.embed_fns.extend(self._sin_cos_pair(freq))

    @staticmethod
    def _sin_cos_pair(freq):
        """Return the (sin, cos) embedding functions bound to one frequency."""
        def sin_embed(x, freq=freq):
            return torch.sin(x * freq)

        def cos_embed(x, freq=freq):
            return torch.cos(x * freq)

        return sin_embed, cos_embed

    def forward(self, x) -> torch.Tensor:
        """
        Apply positional encoding to input.

        Returns a tensor whose last axis is ``len(self.embed_fns)`` times that of ``x``.
        """
        encoded_parts = [embed(x) for embed in self.embed_fns]
        return torch.cat(encoded_parts, dim=-1)

In [None]:
# ---- voxel grid setup ----
voxel_base_num = 100  # number of voxels along the shortest axis before rounding
voxel_grid_config = {
    'range_min': test_dataset.range_min,
    'range_max': test_dataset.range_max,
}

# Extent of each axis relative to the shortest one (shortest axis == 1).
axis_range = voxel_grid_config['range_max'][:3] - voxel_grid_config['range_min'][:3]
axis_range /= torch.min(axis_range)

# Grid resolution per axis, rounded down so each dim is divisible by 4.
voxel_grid_config['world_size'] = (torch.ones(3)*voxel_base_num*axis_range).long() #[150,120,105] #[300,240,210]
voxel_grid_config['world_size'] -= torch.remainder(voxel_grid_config['world_size'], 4)

voxel_grid_config.update({
    'channels': 39, #test_dataset.num_classes
    # NOTE(review): axis_range was normalised by its minimum above, so this is
    # expressed relative to the shortest axis extent rather than in the units
    # of range_min / range_max -- confirm this is intended.
    'voxel_size': (axis_range/voxel_grid_config['world_size'].float()).mean(),
    'anchor': (voxel_grid_config['range_max'][:3] + voxel_grid_config['range_min'][:3]) / 2,
    'density_factor': 10000.,
    't_near': 0.1,
    't_far': 2.0,
    # positional encoding switches
    'use_pos_enc': False,
    'num_freqs': 8,
})

# Optional per-voxel positional encoding, stored channels-first.
voxel_grid_config['pos_enc'] = None
if voxel_grid_config['use_pos_enc']:
    grid_encoder = PositionalEncoder(3, voxel_grid_config['num_freqs'])
    with torch.no_grad():
        linspaces = [
            torch.linspace(voxel_grid_config['range_min'][i], voxel_grid_config['range_max'][i], voxel_grid_config['world_size'][i])
            for i in range(3)
        ]
        coords = torch.cartesian_prod(*linspaces).view(*voxel_grid_config['world_size'],3)
        voxel_grid_config['pos_enc'] = grid_encoder(coords).permute(3,0,1,2)

# Persist the configuration (a dict, so numpy stores it as a pickled object array).
np.save('voxel_grid_config.npy', voxel_grid_config)

# Summary of the resulting grid.
for _caption, _value in (
    ('Channels:', voxel_grid_config['channels']),
    ('World size:', voxel_grid_config['world_size']),
    ('Voxel size:', voxel_grid_config['voxel_size']),
    ('Voxel min:', voxel_grid_config['range_min'][:3]),
    ('Voxel max:', voxel_grid_config['range_max'][:3]),
):
    print(_caption, _value)
if voxel_grid_config['use_pos_enc']:
    print('Positional encoding shape:', voxel_grid_config['pos_enc'].shape)
print()

In [None]:
# Training cell: picks the device, builds the refinement model, the loss terms,
# the optimizer / LR scheduler, and (when TRAIN_MODEL is True) runs the training
# loop over test_dataloader, checkpointing the model whenever the loss improves.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
# NOTE(review): `dtype` is not used anywhere else in this cell.
dtype = torch.float32


# When True, the loop prints shapes at each stage.
VERBOSE = False

# renderer setup
# N_RAYS / N_POINTS are forwarded to render_grids as n_rays / n_points.
N_RAYS = 896 #4000
N_POINTS =  192 #300
N_EPOCHS = 5000

# model setup
TRAIN_MODEL = True

#model = UNet3d(in_channels=1+voxel_grid_config['channels']+2, out_channels=1+voxel_grid_config['channels'], trilinear=False)
#model = model.to(device)
#model.train()

# One extra input channel without positional encoding; with it, one plus the
# number of positional-encoding channels.
num_ignore_channels = 1 if not voxel_grid_config['use_pos_enc'] else 1 + voxel_grid_config['pos_enc'].shape[0]

# Keyword arguments for RefineModel; also saved to disk so the model can be rebuilt.
model_kwargs = {'in_channels' : [1+voxel_grid_config['channels']+num_ignore_channels, 64, 128, 64],
                'out_channels' : [64, 128, 64, 1+voxel_grid_config['channels']],
                'mode' : 'hourglass', # 'resnet', 'hourglass'
                'norm_layer': 'GroupNorm',
                'non_lin': 'LeakyReLU',
                'num_ignore_channels': num_ignore_channels} 
# The dict is stored as a pickled object array (np.load needs allow_pickle=True to read it back).
np.save('model_config.npy', model_kwargs)
model = RefineModel(**model_kwargs)
#model = ResNet3d(in_channels=[1+voxel_grid_config['channels']+2,64,64,64,64,64,64,64,64],
#                 out_channels=[64,64,64,64,64,64,64,64,1+voxel_grid_config['channels']], normalize=False)
model = model.to(device)
model.train()

# training setup
LR = 0.02

# Each loss term below is paired with the weight used in the total loss.
WEIGHT_SEMSEG = 0.5
criterion = torch.nn.CrossEntropyLoss() # FocalLoss(mode='multiclass') # torch.nn.CrossEntropyLoss()
WEIGHT_DICE = 0.5
criterion_dice = LovaszSoftmax() # FocalLoss(mode='multiclass') # torch.nn.CrossEntropyLoss()
WEIGHT_BACKGROUND = 2. #100. #8e-1
criterion_background = torch.nn.BCEWithLogitsLoss() #torch.nn.L1Loss() # torch.nn.HuberLoss(delta=0.01)
# With weight 0 the foreground term is still computed and logged but does not affect the total loss.
WEIGHT_FOREGROUND = 0 # 1e-1
criterion_foreground = torch.nn.L1Loss() # torch.nn.HuberLoss(delta=0.01)
WEIGHT_DEPTH = 1. #1.
criterion_depth = torch.nn.L1Loss()
WEIGHT_DENSITY = 5e-4 # use 1e-1 to match, 5e-3 to regularize
#criterion_density = lambda x: torch.mean(torch.norm(torch.flatten(x, start_dim=1), p=1, dim=-1))
criterion_density = torch.nn.L1Loss()
optimizer = optim.AdamW(model.parameters(), lr=LR)
# Wrap model and optimizer for apex mixed precision (opt_level 'O2').
# NOTE(review): presumably this requires a CUDA device even though `device` may fall back to CPU -- confirm.
model, optimizer = amp.initialize(model, optimizer, opt_level='O2')

# Reduces the LR when the monitored value stops decreasing; stepped once per batch below.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=50)

# NOTE(review): `save_grid_batch` is not read anywhere in this cell.
save_grid_batch = False
if TRAIN_MODEL:
    ##### train
    best_loss = 1e25
    # Restrict the collate function to exactly one fusion step per sample.
    test_dataloader.collate_fn.set_steps(1, 1)
    progress_bar = tqdm(range(N_EPOCHS), total=N_EPOCHS, position=0, leave=True)
    # Counts consecutive batches without a new best loss (used for early stopping).
    patience = 0
    for epoch in progress_bar:
        # `gt` is unpacked but not used inside the loop.
        for pcd, semseg, cam_pose, depth, cam_k, gt in test_dataloader:
            semseg, cam_pose, depth, cam_k = semseg.to(device), cam_pose.to(device), depth.to(device), cam_k.to(device)
            if VERBOSE:
                print('DATA LOADING')
                print('Minibatch output len/shapes\n', 'len(pcd), semseg.shape, cam_pose.shape, depth.shape, cam_k.shape')
                print(len(pcd), semseg.shape, cam_pose.shape, depth.shape, cam_k.shape)
                print()

            ##### FUSION STEP
            # Fuses every point cloud batch into the grid (no gradients), then runs the refinement model on it.
            if VERBOSE:
                print('FUSION STEP')
            grid_batch = None
            # NOTE(review): this initial `k` is overwritten by the enumerate below and never read.
            k = 0
            for k, point_cloud_batch in enumerate(pcd):
                if VERBOSE:
                    print('Point cloud batch shape:', point_cloud_batch.shape)
                # fuse sensor data to voxel logits
                # The previous grid_batch is passed back in, so fusion accumulates across steps.
                with torch.no_grad():
                    grid_batch = batch_fuse_to_grid(point_cloud_batch.to(device), grid_batch,
                                                    world_size=voxel_grid_config['world_size'],
                                                    channels=voxel_grid_config['channels'],
                                                    range_min=voxel_grid_config['range_min'],
                                                    range_max=voxel_grid_config['range_max'],
                                                    density_step=voxel_grid_config['density_factor']*voxel_grid_config['voxel_size'])        
            # switch to probabilities
            # exp() is applied to channels 1..-2 only; the first and last channel are passed through unchanged.
            # NOTE(review): grid_batch is still None here if `pcd` is empty.
            grid_batch_unrefined = torch.cat([grid_batch[:,:1], torch.exp(grid_batch[:,1:-1]), grid_batch[:,-1:]], dim=1)
            # concat positional encoding if necessary
            grid_batch_input = torch.cat([grid_batch_unrefined, 
                                          voxel_grid_config['pos_enc'].unsqueeze(0).expand(grid_batch_unrefined.shape[0],*voxel_grid_config['pos_enc'].shape).to(grid_batch_unrefined.device)],
                                          dim=1) if voxel_grid_config['use_pos_enc'] else grid_batch_unrefined
            # refine, switches to logits
            grid_batch_refined = model(grid_batch_input)
            # if necessary remove positional encoding again
            if voxel_grid_config['use_pos_enc']:
                grid_batch_refined = grid_batch_refined[:,:-voxel_grid_config['pos_enc'].shape[0]]
            if VERBOSE:
                print('Grid batch shape:', grid_batch.shape)
                print()
            # -> at this point grids is (BS, C, H, W, D)

            ##### REFINE STEP
            # (The refinement itself happened above; this section renders the refined grid
            # chunk by chunk and accumulates the loss terms.)
            cam_pose_chunks, cam_k_chunks  = get_chunks(cam_pose, dim=1), get_chunks(cam_k, dim=1)
            semseg_chunks, depth_chunks = get_chunks(semseg, dim=1), get_chunks(depth, dim=1)
            # Accumulators start as Python ints; the .item() calls in the summary below
            # rely on at least one chunk having turned them into tensors.
            loss, loss_semseg, loss_dice, loss_background, loss_foreground, loss_depth, loss_density = 0, 0, 0, 0, 0, 0, 0
            total_chunk_size = cam_pose.shape[1]
            for chunk_id in range(len(semseg_chunks)):
                # Each chunk contributes proportionally to its size along dim 1.
                chunk_size = cam_pose_chunks[chunk_id].shape[1]
                chunk_weight = chunk_size / total_chunk_size
                render, render_depth, composite_mask, gt_labels, gt_depth = render_grids(grid_batch_refined, voxel_grid_config, 
                                                                            cam_pose_chunks[chunk_id], cam_k_chunks[chunk_id], n_rays=N_RAYS,
                                                                            n_points=N_POINTS, semseg=semseg_chunks[chunk_id],
                                                                            depth=depth_chunks[chunk_id], verbose=False, hierarchical=True,
                                                                            downsample_density=True)
                
                # backprop loss between ray_gt and ray_marching
                # Rays labelled -1 are excluded from the semantic / depth terms and used as the background target.
                # NOTE(review): if no ray in a chunk has a label, the masked tensors below are empty -- confirm the criteria handle that.
                ray_mask = gt_labels != -1
                #print(render[ray_mask.squeeze(-1)].shape, gt_labels[ray_mask].shape)
                loss_semseg += chunk_weight * criterion(render[ray_mask.squeeze(-1)], gt_labels[ray_mask])
                loss_dice += chunk_weight * criterion_dice(render[ray_mask.squeeze(-1)], gt_labels[ray_mask])
                background_mask = torch.logical_not(ray_mask).squeeze(-1)
                #print('composite mask:', composite_mask.shape)
                #print('background mask:', background_mask.shape)
                # Clamp away from 0 and 1 so both logarithms below stay finite.
                probs = torch.clamp(composite_mask,1e-3,1-1e-3)
                # NOTE(review): this differs from the standard logit log(p / (1 - p)) by the extra "+ 1" and the 0.5 factor -- confirm this is intended.
                logits = 0.5*(torch.log(probs) - torch.log(1 - probs) + 1)
                loss_background += chunk_weight * criterion_background(logits, background_mask.float())
                #loss_background += chunk_weight * criterion_background(composite_mask[background_mask], torch.ones_like(composite_mask[background_mask]))
                # Pushes composite_mask towards 0 on labelled rays.
                loss_foreground += chunk_weight * criterion_foreground(composite_mask[ray_mask.squeeze(-1)], torch.zeros_like(composite_mask[ray_mask.squeeze(-1)]))
                #print(ray_mask.shape, render.shape, composite_mask.shape)
                if gt_depth is not None: 
                    loss_depth += chunk_weight * criterion_depth(render_depth[ray_mask.squeeze(-1)], gt_depth[ray_mask])
                else:
                    loss_depth += torch.zeros(1).to(semseg.device)
                # Keeps channel 0 of the refined grid close to channel 0 of the unrefined grid.
                # This term does not depend on the chunk; it is re-evaluated per chunk and scaled by chunk_weight.
                loss_density += chunk_weight * criterion_density(grid_batch_refined[:,0], grid_batch_unrefined[:,0]) # criterion_density(grid_batch_refined[:,0])

            # Weighted sum of all terms; a single backward pass covers every chunk.
            loss = WEIGHT_SEMSEG * loss_semseg + WEIGHT_DICE * loss_dice + WEIGHT_DEPTH * loss_depth  + WEIGHT_BACKGROUND * loss_background + WEIGHT_FOREGROUND * loss_foreground + WEIGHT_DENSITY*loss_density
            optimizer.zero_grad()
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()
            # The scheduler is stepped once per batch with the training loss.
            scheduler.step(loss.item())
            # NOTE(review): the checkpoint is written after optimizer.step(), so the saved
            # weights are the updated ones, not the ones that produced `best_loss`.
            if loss.item() < best_loss:
                best_loss = loss.item()
                torch.save(model.state_dict(), 'net_best_'+MODE+'.pt')
                patience = 0
            else:
                patience += 1
            # Show current training stats
            episode_summary = [f"{epoch+1}:"] + [f'L: {loss.item():.3f}, Seg: {loss_semseg.item():.3f}, DI: {loss_dice.item():.3f}, Z: {loss_depth.item():.3f}, BG:{loss_background.item():.3f}, D:{loss_density.item():.3f}, FG:{loss_foreground.item():.3f}, Best: {best_loss:.3f}, P: {patience}']
            # Set progress bar
            progress_bar.set_description("".join(episode_summary))
        # Early stopping: `patience` counts batches, but it is only checked once per epoch.
        if patience > 200:
            break
        # z=0.033