## todo: 1. kpconv training; 2. inference; 3. inference result visualization; 4. documentation


## Create dataset

In [1]:
# Make the parent of the current working directory importable, so that
# packages located there can be imported by the cells below.
import os
import sys

path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
# Only prepend once, so re-running the cell does not add duplicates.
if sys.path.count(path) == 0:
    sys.path.insert(0, path)

In [2]:
# import packages
from omegaconf import OmegaConf
import pyvista as pv
import torch
import numpy as np
from tqdm.auto import tqdm
import time
import laspy

from rock_detection_3d.datasets.segmentation.rock_las import RockLASDataset




In [3]:
# configure visualization: set the display / PyVista environment variables
# in one go, then start a virtual X server for them to talk to.
os.environ.update(
    {
        "DISPLAY": ":1.0",
        "PYVISTA_OFF_SCREEN": "true",
        "PYVISTA_PLOT_THEME": "true",
        "PYVISTA_USE_PANEL": "true",
        "PYVISTA_AUTO_CLOSE": "false",
    }
)
# Launch Xvfb on display :1 in the background (output discarded); the value
# shown for this cell is the exit status returned by os.system.
os.system("Xvfb :1 -screen 0 1024x768x24 > /dev/null 2>&1 &")

0

In [4]:
# configure dataset params

# Root directory of the dataset; it is joined with "data" to form the
# `dataroot` entry of the YAML configuration.
DIR = "" # Replace with your root directory, the data will go in DIR/data.
# Interpolated into the `color` entry of the YAML configuration.
USE_COLOR = True #@param {type:"boolean"}

In [5]:
# Dataset configuration written as YAML. The two %-placeholders receive the
# data root directory (%s) and the colour flag (%r) before parsing.
data_root = os.path.join(DIR, "data")
pbr_yaml = """
class: None # shapenet.ShapeNetDataset
task: segmentation
dataroot: %s
color: %r                                     # Use color vectors as features
first_subsampling: 0.02                       # Grid size of the input data
pre_transforms:                               # Offline transforms, done only once        
    - transform: GridSampling3D
      params:
        size: ${first_subsampling}
train_transforms:                             # Data augmentation pipeline
    - transform: RandomNoise
      params:
        sigma: 0.01
        clip: 0.05
    - transform: RandomScaleAnisotropic
      params:
        scales: [0.9,1.1]
    - transform: AddOnes
    - transform: AddFeatsByKeys
      params:
        list_add_to_x: [True]
        feat_names: ["ones"]
        delete_feats: [True]
test_transforms:
    - transform: AddOnes
    - transform: AddFeatsByKeys
      params:
        list_add_to_x: [True]
        feat_names: ["ones"]
        delete_feats: [True]
""" % (data_root, USE_COLOR)

from omegaconf import OmegaConf
# Parse the YAML text into the config object consumed by the dataset.
params = OmegaConf.create(pbr_yaml)


In [6]:
# create dataset
# Build the dataset from the parsed configuration; the bare name on the last
# line makes the notebook display the dataset summary.
dataset = RockLASDataset(params)
dataset

Dataset: RockLASDataset 
[0;95mtrain_pre_batch_collate_transform [0m= None
[0;95mval_pre_batch_collate_transform [0m= None
[0;95mtest_pre_batch_collate_transform [0m= None
[0;95mpre_transform [0m= Compose([
    GridSampling3D(grid_size=0.02, quantize_coords=False, mode=mean),
])
[0;95mtest_transform [0m= Compose([
    AddOnes(),
    AddFeatsByKeys(ones=True),
])
[0;95mtrain_transform [0m= Compose([
    RandomNoise(sigma=0.01, clip=0.05),
    RandomScaleAnisotropic([0.9, 1.1]),
    AddOnes(),
    AddFeatsByKeys(ones=True),
])
[0;95mval_transform [0m= None
[0;95minference_transform [0m= Compose([
    GridSampling3D(grid_size=0.02, quantize_coords=False, mode=mean),
    AddOnes(),
    AddFeatsByKeys(ones=True),
])
Size of [0;95mtrain_dataset [0m= 29
Size of [0;95mtest_dataset [0m= 10
Size of [0;95mval_dataset [0m= 10
[0;95mBatch size =[0m None

In [7]:
# visually inspect dataset 

#@title Plot samples with part annotations { run: "auto" }
objectid_1 = 3 #@param {type:"slider", min:0, max:100, step:1}
objectid_2 = 4 #@param {type:"slider", min:0, max:100, step:1}
objectid_3 = 5 #@param {type:"slider", min:0, max:100, step:1}

# One subplot per selected training sample, points coloured by their label.
samples = [objectid_1, objectid_2, objectid_3]
p = pv.Plotter(notebook=True, shape=(1, len(samples)), window_size=[1024, 412])
for i, sample_id in enumerate(samples):
    p.subplot(0, i)
    sample = dataset.train_dataset[sample_id]
    point_cloud = pv.PolyData(sample.pos.numpy())
    point_cloud['y'] = sample.y.numpy()
    p.add_points(point_cloud, show_scalar_bar=False, point_size=4)
    p.camera_position = [-1, 5, -10]
p.show()

ViewInteractiveWidget(height=412, layout=Layout(height='auto', width='100%'), width=1024)

## Create segmentation model

In [8]:
# import packages
from torch_points3d.applications.kpconv import KPConv

In [9]:
# create KPConv model

color = 3  # use RGB data


class SegKPConv(torch.nn.Module):
    """KPConv U-Net wrapped with the hooks used by the data loaders and tracker.

    The wrapper stores the labels, batch vector, network output and loss of
    the most recent forward pass as attributes, and exposes them through the
    ``get_*`` accessors.

    Args:
        cat_to_seg: accepted for call compatibility; not used by this class.
    """

    def __init__(self, cat_to_seg):
        super().__init__()
        self.unet = KPConv(
            architecture="unet",
            input_nc=color,
            output_nc=2,  # isPBR & notPBR
            num_layers=4,
            # grid size at the entry of the network; should be consistent
            # with the dataset first sampling resolution
            in_grid_size=params['first_subsampling'],
        )

    @property
    def conv_type(self):
        """ This is needed by the dataset to infer which batch collate should be used"""
        return self.unet.conv_type

    def get_batch(self):
        """Return the batch vector stored by the last forward pass."""
        return self.batch

    def get_output(self):
        """ This is needed by the tracker to get access to the outputs of the network"""
        return self.output

    def get_labels(self):
        """ Needed by the tracker in order to access ground truth labels"""
        return self.labels

    def get_current_losses(self):
        """ Entry point for the tracker to grab the loss """
        return {"loss_seg": float(self.loss_seg)}

    def forward(self, data):
        """Run the network on ``data`` and store output, labels and loss.

        Args:
            data: batch object providing ``y`` (labels) and ``batch``; it is
                passed unchanged to the wrapped network.

        Returns:
            The per-point network output (the ``x`` attribute of the object
            returned by the wrapped network).
        """
        self.labels = data.y
        self.batch = data.batch

        # Forward through the unet (which includes the output head)
        output_batch = self.unet(data)
        self.output = output_batch.x

        # Set loss for the backward pass.
        # cross_entropy applies log_softmax before the negative log-likelihood.
        # nll_loss on its own expects log-probabilities, and nothing visible
        # here normalises the network output (NOTE(review): confirm that the
        # KPConv head emits raw scores). log_softmax is idempotent, so the
        # value is unchanged if the output already holds log-probabilities.
        self.loss_seg = torch.nn.functional.cross_entropy(self.output, self.labels)
        return self.output

    def get_spatial_ops(self):
        """Expose the spatial operations of the wrapped network."""
        return self.unet.get_spatial_ops()

    def backward(self):
        """Back-propagate the loss stored by the last forward pass."""
        self.loss_seg.backward()


# Instantiate the model; the constructor accepts this argument but does not
# use it. The bare name on the last line displays the module structure.
model = SegKPConv(dataset.class_to_segments)
model

SegKPConv(
  (unet): KPConvUnet(
    (down_modules): ModuleList(
      (0): KPDualBlock(
        Nb parameters: 34304
        (blocks): ModuleList(
          (0): SimpleBlock(
            Nb parameters: 3968; None; RadiusNeighbourFinder {'_radius': 0.05, '_max_num_neighbors': 25, '_conv_type': 'partial_dense'}
            (kp_conv): KPConvLayer(InF: 4, OutF: 64, kernel_pts: 15, radius: 0.03, KP_influence: linear, Add_one: False)
            (bn): FastBatchNorm1d(
              (batch_norm): BatchNorm1d(64, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True)
            )
            (activation): LeakyReLU(negative_slope=0.1)
          )
          (1): ResnetBBlock(
            Nb parameters: 30336
            (kp_conv): SimpleBlock(
              Nb parameters: 15424; None; RadiusNeighbourFinder {'_radius': 0.05, '_max_num_neighbors': 25, '_conv_type': 'partial_dense'}
              (kp_conv): KPConvLayer(InF: 32, OutF: 32, kernel_pts: 15, radius: 0.03, KP_influence: line

## The data loaders and CPU pre computing features
KPConv is quite demanding on spatial operations such as grid sampling and radius search. On the network loaded here we have 10 KPConv layers on the encoder, which means 10 radius search operations with varying numbers of neighbours. We observed a significant performance gain by moving those operations to the CPU, where they can easily be optimised with suitable data structures such as a kd-tree. We use [nanoflann](https://github.com/jlblancoc/nanoflann) in the back-end, a 3D optimised kd-tree implementation. Note that this is beneficial only if you have access to multiple CPU threads.

You can decide to precompute those spatial operations by setting the `precompute_multi_scale` parameter to `True` when creating the data loaders. The dataset will mine the model to figure out which spatial operations are required and in which order.

In [10]:
# Loader settings; the trailing comments record alternative values.
NUM_WORKERS = 2  # 4
BATCH_SIZE = 2  # 16
# The model is passed in so the dataset can determine which spatial
# operations to precompute (enabled by precompute_multi_scale=True).
dataset.create_dataloaders(
    model,
    batch_size=BATCH_SIZE, 
    num_workers=NUM_WORKERS, 
    shuffle=True, 
    precompute_multi_scale=True 
    )

In [11]:
# Pull one batch from the training loader and list the attributes it carries.
sample = next(iter(dataset.train_dataloader))
sample.keys

['x',
 'y',
 'pos',
 'multiscale',
 'upsample',
 'batch',
 'category',
 'center',
 'file_name',
 'grid_size',
 'id_scan',
 'origin_id',
 'ptr',
 'scale']

Our `sample` contains the precomputed spatial information in the `multiscale` (encoder side) and `upsample` (decoder side) attributes. The decoder precomputation is quite simple and just involves some basic caching for the nearest-neighbour interpolation operation. Let's take a look at the encoder side of things first.

In [12]:
# Display the precomputed encoder-side data, one entry per layer.
sample.multiscale

[Batch(batch=[17482], idx_neighboors=[17482, 25], pos=[17482, 3]),
 Batch(batch=[17482], idx_neighboors=[17482, 25], pos=[17482, 3]),
 Batch(batch=[6995], grid_size=[2], idx_neighboors=[6995, 25], pos=[6995, 3]),
 Batch(batch=[6995], idx_neighboors=[6995, 25], pos=[6995, 3]),
 Batch(batch=[1706], grid_size=[2], idx_neighboors=[1706, 25], pos=[1706, 3]),
 Batch(batch=[1706], idx_neighboors=[1706, 25], pos=[1706, 3]),
 Batch(batch=[404], grid_size=[2], idx_neighboors=[404, 25], pos=[404, 3]),
 Batch(batch=[404], idx_neighboors=[404, 25], pos=[404, 3]),
 Batch(batch=[116], grid_size=[2], idx_neighboors=[116, 25], pos=[116, 3]),
 Batch(batch=[116], idx_neighboors=[116, 25], pos=[116, 3])]

`sample.multiscale` contains 10 different versions of the input batch. Each of these versions contains the location of the points in `pos` as well as the neighbours of these points in the previous point cloud. We will first look at the points coming out of each downsampling layer (strided convolution); there are 5 of them.

In [13]:
#@title Successive downsampling {run:"auto"}
sample_in_batch = 0 #@param {type:"slider", min:0, max:5, step:1}
ms_data = sample.multiscale
# Every second entry of the multiscale list is plotted, one subplot each.
num_downsize = len(ms_data) // 2
p = pv.Plotter(notebook=True, shape=(1, num_downsize), window_size=[1024, 256])
for i in range(num_downsize):
    p.subplot(0, i)
    level = ms_data[2 * i]
    # Keep only the points of the selected sample within the batch.
    pos = level.pos[level.batch == sample_in_batch].numpy()
    point_cloud = pv.PolyData(pos)
    point_cloud['y'] = pos[:, 1]  # colour by the second coordinate
    p.add_points(point_cloud, show_scalar_bar=False, point_size=3)
    p.add_text("Layer {}".format(i + 1), font_size=10)
    p.camera_position = [-1, 5, -10]
p.show()

ViewInteractiveWidget(height=256, layout=Layout(height='auto', width='100%'), width=1024)

Let's now take one point in a layer (the query point) and show its neighbours in the previous layer (the support points).

In [14]:
#@title Explore Neighborhood {run: "auto"}
selected_layer = 7 #@param {type:"slider", min:1, max:9, step:1}
sample_in_batch = 0 #@param {type:"slider", min:0, max:5, step:1}
point1_id = 3 #@param {type:"slider", min:0, max:600, step:1}
point2_id =  8 #@param {type:"slider", min:0, max:600, step:1}

p = pv.Plotter(notebook=True, shape=(1, 2), window_size=[1024, 412])
# The two query points and the colour each one is drawn with.
highlights = [(point1_id, 'red'), (point2_id, 'green')]

# Selected layer (right panel): the cloud plus the two query points
p.subplot(0, 1)
ms_data = sample.multiscale[selected_layer]
in_sample = ms_data.batch == sample_in_batch
pos = ms_data.pos[in_sample].numpy()
nei = ms_data.idx_neighboors[in_sample]
point_cloud = pv.PolyData(pos)
p.add_points(point_cloud, show_scalar_bar=False, point_size=3, opacity=0.3)
for point_id, highlight_color in highlights:
    p.add_points(pos[point_id, :], show_scalar_bar=False, point_size=7.0, color=highlight_color)
p.camera_position = [-1, 5, -10]

# Previous layer (left panel): the cloud plus the neighbours of each query point
p.subplot(0, 0)
ms_data = sample.multiscale[selected_layer - 1]
pos = ms_data.pos[ms_data.batch == sample_in_batch].numpy()
point_cloud = pv.PolyData(pos)
p.add_points(point_cloud, show_scalar_bar=False, point_size=3, opacity=0.3)
for point_id, highlight_color in highlights:
    neighbour_idx = nei[point_id]
    nei_pos = ms_data.pos[neighbour_idx].numpy()
    # Drop entries whose neighbour index is negative.
    nei_pos = nei_pos[neighbour_idx >= 0]
    p.add_points(nei_pos, show_scalar_bar=False, point_size=3.0, color=highlight_color)
p.camera_position = [-1, 5, -10]

p.show()

ViewInteractiveWidget(height=412, layout=Layout(height='auto', width='100%'), width=1024)

## Train neural network

In [15]:
class Trainer:
    """Training, validation, test and checkpoint handling for a segmentation model.

    The model is expected to provide ``forward(data)``, ``backward()``,
    ``get_output()`` and a ``loss_seg`` attribute; the dataset is expected to
    provide ``train_dataloader``, ``val_dataloader``, ``test_dataloaders`` and
    ``get_tracker``.

    Args:
        model: the network to train.
        dataset: dataset object owning the data loaders and the tracker factory.
        num_epoch: number of epochs run by :meth:`fit`.
        device: device the model and batches are moved to.
        checkpoint_path: directory for the per-epoch checkpoint files
            (created if missing).
    """

    def __init__(self, model, dataset, num_epoch = 60, device=torch.device('cuda'), checkpoint_path="model/kpconv"):
        self.num_epoch = num_epoch
        self._model = model
        self._dataset = dataset
        self.device = device
        self.checkpoint_path = checkpoint_path
        os.makedirs(self.checkpoint_path, exist_ok=True)

    def _checkpoint_file(self, epoch):
        """Return the checkpoint file path used for ``epoch``."""
        return os.path.join(self.checkpoint_path, "{epoch}.pt".format(epoch=epoch))

    def save_model(self, epoch):
        """Write model state, optimizer state and last loss to ``<epoch>.pt``."""
        f = self._checkpoint_file(epoch)
        torch.save({
            'epoch': epoch,
            'model_state_dict': self._model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            # Detached CPU copy: keeps the autograd graph out of the checkpoint
            # and lets it be read on machines without the training device.
            'loss': self._model.loss_seg.detach().cpu(),
            }, f)

    def load_model(self, epoch):
        """Load the weights saved for ``epoch`` and return the model in eval mode.

        Raises:
            AssertionError: if the checkpoint file does not exist.
        """
        f = self._checkpoint_file(epoch)
        assert os.path.isfile(f), "checkpoint not found: {}".format(f)
        # map_location lets a checkpoint written on one device load on another.
        checkpoint = torch.load(f, map_location=self.device)
        self._model.load_state_dict(checkpoint['model_state_dict'])
        self._model.to(self.device)
        self._model.eval()
        return self._model

    def resume_model(self, epoch):
        """Restore model weights and optimizer state saved for ``epoch``.

        Also creates the tracker, so :meth:`test` can be called afterwards.

        Raises:
            AssertionError: if the checkpoint file does not exist.
        """
        f = self._checkpoint_file(epoch)
        assert os.path.isfile(f), "checkpoint not found: {}".format(f)
        checkpoint = torch.load(f, map_location=self.device)

        # Restore and move the model BEFORE loading the optimizer state:
        # the optimizer casts its state to the device its parameters are on
        # at load time, so the parameters must already be on the target device.
        self._model.load_state_dict(checkpoint['model_state_dict'])
        self._model.to(self.device)

        self.optimizer = torch.optim.Adam(self._model.parameters(), lr=0.001)
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.tracker = self._dataset.get_tracker(False, True)

    def fit(self):
        """Train for ``num_epoch`` epochs, validating and checkpointing after each.

        NOTE: a fresh optimizer and tracker are created on every call and the
        epoch counter starts at 0, so state restored by :meth:`resume_model`
        only carries over through the model weights.
        """
        self.optimizer = torch.optim.Adam(self._model.parameters(), lr=0.001)
        self.tracker = self._dataset.get_tracker(False, True)

        for i in range(self.num_epoch):
            print("=========== EPOCH %i ===========" % i)
            time.sleep(0.5)
            self.train_epoch()
            self.tracker.publish(i)
            self.valid_epoch()
            self.tracker.publish(i)
            self.save_model(i)

    def train_epoch(self):
        """Run one optimisation pass over the training loader."""
        self._model.to(self.device)
        self._model.train()
        self.tracker.reset("train")
        train_loader = self._dataset.train_dataloader
        iter_data_time = time.time()
        with tqdm(train_loader) as tq_train_loader:
            for i, data in enumerate(tq_train_loader):
                t_data = time.time() - iter_data_time
                iter_start_time = time.time()
                self.optimizer.zero_grad()
                data = data.to(self.device)
                self._model.forward(data)
                self._model.backward()
                self.optimizer.step()
                # Metrics are only tracked on every 10th training batch.
                if i % 10 == 0:
                    self.tracker.track(self._model)

                tq_train_loader.set_postfix(
                    **self.tracker.get_metrics(),
                    data_loading=float(t_data),
                    iteration=float(time.time() - iter_start_time),
                )
                iter_data_time = time.time()

    def _evaluate(self, loader, stage, after_batch=None):
        """Run the model in eval mode over ``loader``, tracking every batch.

        Args:
            loader: data loader to iterate over.
            stage: stage name passed to ``tracker.reset``.
            after_batch: optional callable invoked with each batch after it
                has been tracked.
        """
        self._model.to(self.device)
        self._model.eval()
        self.tracker.reset(stage)
        iter_data_time = time.time()
        # No gradients are needed for evaluation; skipping the autograd graph
        # saves memory and time.
        with torch.no_grad(), tqdm(loader) as tq_loader:
            for data in tq_loader:
                t_data = time.time() - iter_data_time
                iter_start_time = time.time()
                data = data.to(self.device)
                self._model.forward(data)
                self.tracker.track(self._model)

                tq_loader.set_postfix(
                    **self.tracker.get_metrics(),
                    data_loading=float(t_data),
                    iteration=float(time.time() - iter_start_time),
                )
                if after_batch is not None:
                    after_batch(data)
                iter_data_time = time.time()

    def valid_epoch(self):
        """Evaluate the model on the validation loader."""
        self._evaluate(self._dataset.val_dataloader, "val")

    def test(self, save=True):
        """Evaluate on the first test loader and optionally write predictions.

        Args:
            save: when true, each batch's predictions are written as LAS
                files through :meth:`save_las`.
        """
        # Allow testing right after load_model(), before any tracker exists.
        if getattr(self, "tracker", None) is None:
            self.tracker = self._dataset.get_tracker(False, True)

        self._evaluate(
            self._dataset.test_dataloaders[0],
            "test",
            after_batch=self.save_las if save else None,
        )
        self.tracker.publish(self.num_epoch - 1)
        if save:
            print("saved prediction las files")

    def save_las(self, batch_data):
        """Write one LAS file per sample of ``batch_data`` using the last model output.

        The predicted class of each point is the argmax of the model output
        along dimension 1.
        """
        pred = self._model.get_output()
        outputs = torch.argmax(pred, 1)

        for i in torch.unique(batch_data.batch).tolist():
            idx = batch_data.batch == i
            self.write_las(
                batch_data.file_name[i],
                batch_data.pos[idx],
                batch_data.x[idx],
                outputs[idx],
                # assumes `center` stores 3 consecutive values per sample -- TODO confirm
                batch_data.center[i*3:i*3+3],
                batch_data.scale[i],
            )

    def write_las(self, file_name, pos, x, y, center, scale):
        """Write one predicted point cloud to ``data/rocklas/prediction/pred_<file_name>.las``.

        Args:
            file_name: name inserted into the output file name.
            pos: point coordinates; divided by ``scale`` and shifted by
                ``center`` before writing.
            x: point features; the first three columns are written as
                red / green / blue after scaling by 2**16.
            y: predicted class per point (1 is treated as PBR, 0 as not PBR).
            center: offset added to the rescaled coordinates.
            scale: factor the coordinates are divided by.
        """
        path = 'data/rocklas/prediction'
        # Create the prediction directory itself (not the checkpoint directory).
        os.makedirs(path, exist_ok=True)

        pos = pos / scale
        pos = pos + center
        pos = pos.cpu().detach().numpy()

        # Clip before the cast: a feature value of 1.0 scales to 65536, which
        # does not fit into uint16.
        color = np.clip(
            (x * (2**16)).cpu().detach().numpy(), 0, 2**16 - 1
        ).astype(np.uint16)

        PBR_ids = (y==1).cpu().detach().numpy()
        notPBR_ids = (y==0).cpu().detach().numpy()
        # Both extra dimensions default to NaN; members of a class get that
        # class's marker value (0 for isPBR, 1 for notPBR).
        # NOTE(review): marker values kept as found -- confirm they match the
        # convention of the input LAS files.
        isPBR = np.full(pos.shape[0], np.nan)
        notPBR = isPBR.copy()
        isPBR[PBR_ids] = 0
        notPBR[notPBR_ids] = 1

        f = os.path.join(path, "pred_{f}.las".format(f=file_name))
        header = laspy.LasHeader(point_format=2, version="1.2")
        header.scales = np.array([0.01, 0.01, 0.01])
        header.add_extra_dim(laspy.ExtraBytesParams(name="isPBR", type=np.float64))
        header.add_extra_dim(laspy.ExtraBytesParams(name="notPBR", type=np.float64))

        las = laspy.LasData(header)
        las.x = pos[:, 0]
        las.y = pos[:, 1]
        las.z = pos[:, 2]
        las.red = color[:, 0]
        las.green = color[:, 1]
        las.blue = color[:, 2]
        las.isPBR = isPBR
        las.notPBR = notPBR

        las.write(f)

In [16]:
# Build the trainer with its defaults: 60 epochs, CUDA device, checkpoints under "model/kpconv".
trainer = Trainer(model, dataset)

In [None]:
# Run training; each epoch is followed by validation and a checkpoint write.
trainer.fit()



  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]



  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]



  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]



  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]



  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]



  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]



  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]



  0%|          | 0/15 [00:00<?, ?it/s]

In [None]:
# Load the TensorBoard notebook extension and open it on the log directory.
%load_ext tensorboard
%tensorboard --logdir tensorboard/ # Change for your log location

## Inference

In [None]:
#trainer.resume_model(49)

In [None]:
# Evaluate on the first test dataloader; with the default save=True the
# predictions of every batch are also written out as LAS files.
trainer.test()

In [None]:
from rock_detection_3d.utils.las_reader import Read_Las_from_Path, Read_Las_from_Json

In [None]:
# Reader over the directory that Trainer.write_las writes predictions to.
pred_las_reader = Read_Las_from_Path('data/rocklas/prediction')
print(len(pred_las_reader))

In [None]:
# Reader built from the test split JSON file.
# (presumably lists the raw test point clouds -- verify against Read_Las_from_Json)
las_reader = Read_Las_from_Json('data/rocklas/raw/test_split.json')
print(len(las_reader))

In [None]:
# Index of the cloud to display.
# NOTE(review): assumes both readers order their clouds identically, so the
# same index refers to the same scan -- confirm.
idx = 6

p = pv.Plotter(notebook=True, shape=(1, 2), window_size=[1024, 412])

# Left panel: predicted labels; right panel: labels from the test split.
for column, reader in enumerate((pred_las_reader, las_reader)):
    p.subplot(0, column)
    # The colour array is unpacked into `_rgb` rather than `color`, so the
    # module-level `color` used as the model's input channel count is not
    # overwritten when this cell runs.
    pos, _rgb, y = reader.get_normalized(idx)
    point_cloud = pv.PolyData(pos)
    point_cloud['y'] = y
    p.add_points(point_cloud, show_scalar_bar=False, point_size=4)
    p.camera_position = [-1, 5, -10]

p.show()