In [1]:
# This notebook is trying to train Point Augment by averaging out the class labels obtained from a randla decoder.

In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
import time
import numpy as np
import torch
import torch
import torch.optim as optim

from imps.data.scannet import ScanNetScene, CLASS_NAMES, MOVABLE_INSTANCE_NAMES
from imps.sqn.model import Randla
from imps.sqn.data_utils import prepare_input
from imps.point_augment.Common import loss_utils, point_augment_utils

from imps.point_augment.Common import loss_utils
from imps.point_augment.Augmentor.augmentor import Augmentor

SCENE_DIR = '/app/mnt/scans/scene0040_00'

N_POINTS = int(1.5e5)
DEVICE = 'cuda'

IGNORED_LABELS = [0]



Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
scene = ScanNetScene(SCENE_DIR)
surface_points, surface_colors, point_labels, point_instances, instance_number = scene.create_points_colors_labels_from_pc(N_POINTS)

In [3]:
center = surface_points.mean(axis=0, keepdims=True)
surface_points -= center
    
pts_min = surface_points.min(axis=0, keepdims=True)
pts_max = surface_points.max(axis=0, keepdims=True)
surface_points = (surface_points - pts_min) / (pts_max - pts_min)
surface_points -= 0.5

In [4]:
seg2idx = {c: i for i,c in enumerate(CLASS_NAMES)}
movable_instances = dict((k, seg2idx[k]) for k in list(set(MOVABLE_INSTANCE_NAMES)))

In [15]:
print(seg2idx)

{'unannotated': 0, 'wall': 1, 'floor': 2, 'cabinet': 3, 'bed': 4, 'chair': 5, 'sofa': 6, 'table': 7, 'door': 8, 'window': 9, 'bookshelf': 10, 'picture': 11, 'counter': 12, 'desk': 13, 'curtain': 14, 'refrigerator': 15, 'showercurtain': 16, 'toilet': 17, 'sink': 18, 'bathtub': 19, 'otherfurniture': 20}


In [5]:
label_index_dict = {} 
for index, name in enumerate(list(seg2idx.keys())):
    res_list = np.where(point_labels == index) # extracting indices corresponding to a label
    label_index_dict[name] = res_list

{'unannotated': (array([     3,      8,     10, ..., 149985, 149986, 149993]),), 'wall': (array([     2,      4,     12, ..., 149983, 149995, 149996]),), 'floor': (array([     6,      7,     14, ..., 149994, 149997, 149998]),), 'cabinet': (array([    24,     42,    132, ..., 149926, 149933, 149988]),), 'bed': (array([], dtype=int64),), 'chair': (array([   149,    202,    206, ..., 149929, 149967, 149984]),), 'sofa': (array([    16,    119,    166, ..., 149938, 149943, 149987]),), 'table': (array([   560,    806,    836, ..., 149757, 149935, 149966]),), 'door': (array([    34,     66,    136, ..., 149950, 149969, 149980]),), 'window': (array([    75,    112,    121, ..., 149527, 149885, 149888]),), 'bookshelf': (array([    11,     19,     23, ..., 149975, 149982, 149992]),), 'picture': (array([], dtype=int64),), 'counter': (array([], dtype=int64),), 'desk': (array([     0,      1,      5, ..., 149990, 149991, 149999]),), 'curtain': (array([], dtype=int64),), 'refrigerator': (array([], d

In [6]:
# needed for calculation of randla_loss

class_counts = []
for c in range(len(CLASS_NAMES.keys())):
    class_counts.append(np.sum(point_labels == c))
class_counts = np.array(class_counts)

for ign in IGNORED_LABELS:
    class_counts[ign] = 0
class_weights = class_counts / class_counts.sum()

In [7]:
# excluding out point_mask for now.

def get_randla_loss(logits, labels, class_weights):
    class_weights = torch.from_numpy(class_weights).float().to(logits.device)
    logits = logits.reshape(-1, len(class_weights))
    labels = labels.reshape(-1)

    criterion = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='none')
    output_loss = criterion(logits, labels)
    
   
    n_points = output_loss.shape[0]
        
    output_loss = output_loss.sum() / n_points
    
    return output_loss

In [8]:
object_id_instances_list, original_instance_indices, instance_coordinates = point_augment_utils.get_instance_indices_dict(seg2idx, movable_instances, surface_points, point_labels, point_instances)

In [9]:
features = torch.FloatTensor(surface_colors).unsqueeze(0).to(DEVICE)
xyz = torch.FloatTensor(surface_points).unsqueeze(0)
point_labels = torch.LongTensor(point_labels).unsqueeze(0).to(DEVICE)

In [10]:
# initializing colors, surface_points and labels as tensors. Prepare the input for neural net operation. Initialize RandLA net

input_points, input_neighbors, input_pools, feat_shape = prepare_input(xyz, k=16, num_layers=3, encoder_dims = [8,32,64], sub_sampling_ratio=4, 
                                                           device=DEVICE)


In [11]:
randla = Randla(d_feature=3, d_in=8, encoder_dims=[8, 32, 64], device=DEVICE, num_class=len(CLASS_NAMES),interpolator='keops')

In [12]:
# initialize epoch, dimension, augmentor, optimizers
dim = 3
augmentor = Augmentor().cuda()
optimizer_r = optim.Adam(randla.parameters(), lr=1e-3) # optimizer for randla 
optimizer_a = torch.optim.Adam(                        # optimzer for PA
            augmentor.parameters(),
            lr=0.001,
            betas=(0.9, 0.999),
            eps=1e-08,
        )

In [14]:
logits = randla.forward(features, input_points, input_neighbors, input_pools) # had to change it for randla, was modified according to sqn
output_label_dict = point_augment_utils.averaging_labels_from_randla_decoder(label_index_dict, logits, seg2idx)


In [15]:
print(output_label_dict)

{0: tensor([ 1.1515,  0.8761,  1.0159,  0.0380,  0.4630,  1.2239,  0.5455,  0.2650,
         0.0143,  0.1882,  0.6193,  0.0129,  0.3107,  0.1975,  0.5618,  0.4228,
        -0.0173,  0.1909,  0.1468,  0.7037,  0.2597], device='cuda:0',
       grad_fn=<MeanBackward1>), 1: tensor([ 1.0039,  0.8684,  1.0846, -0.0073,  0.4354,  1.0456,  0.5352,  0.2176,
         0.1193,  0.2902,  0.5908, -0.0535,  0.3238,  0.1393,  0.4501,  0.2706,
         0.0581,  0.2319,  0.1521,  0.7485,  0.3174], device='cuda:0',
       grad_fn=<MeanBackward1>), 2: tensor([ 0.4962, -0.0302,  0.5091,  0.2253, -0.0712,  0.5387, -0.1200,  0.0889,
        -0.0736, -0.1281,  0.0538,  0.1933,  0.4112,  0.0708,  0.6105,  0.0504,
        -0.0090, -0.1212,  0.4091,  0.5621,  0.0886], device='cuda:0',
       grad_fn=<MeanBackward1>), 3: tensor([ 1.2378,  0.4813,  1.2310,  0.0153,  0.2130,  1.0912,  0.2269,  0.0127,
         0.0924,  0.0869,  0.3240, -0.0102,  0.2859, -0.0320,  0.5548, -0.0310,
        -0.0225, -0.0611,  0.3809, 

In [16]:
randla.train()
augmentor=augmentor.train()

for epoch in range(0, 100):  
    # Original Point Cloud operation: 
    start_step1 = time.time()
    
    logits = randla.forward(features, input_points, input_neighbors, input_pools) # had to change it for randla, was modified according to sqn
    output_label_dict = averaging_labels_from_randla_decoder(label_index_dict, logits, len_instances = len(CLASS_NAMES))
    table_logits = output_label_dict[7]

    
    
    # Noise variable added to generate feature matrix 
    optimizer_a.zero_grad()
    optimizer_r.zero_grad()
    for key,value in instance_coordinates.items():
        for index, coord in enumerate(value):
            coord = torch.FloatTensor(coord).unsqueeze(0).to(DEVICE)
            coord=coord.transpose(2,1).contiguous()
            noise = 0.02 * torch.randn(1, 1024).cuda()
            aug_instance = augmentor(coord, noise)
            augmented_scene = point_augment_utils.add_aug_instance_to_pc(aug_instance, surface_points, original_instance_indices[key][index])

    augmented_scene = torch.FloatTensor(augmented_scene).unsqueeze(0)
    
    # Augmented Point Cloud Operation:
    
    # DOUBT:# I have to do this. As augmented point cloud is generated during each epoch and I need to prepare input for it. Any solution?
    input_points_aug, input_neighbors_aug, input_pools_aug, feat_shape = prepare_input(augmented_scene, k=16, num_layers=3, sub_sampling_ratio=4, 
                                                           device=DEVICE)  



    logits_aug = randla.forward(features, input_points_aug, input_neighbors_aug, input_pools_aug)
    output_label_dict_aug = averaging_labels_from_randla_decoder(label_index_dict, logits_aug, len_instances = len(CLASS_NAMES))
    aug_table_logits = output_label_dict_aug[7] # calculate the logits for the augmented table

  
    
                     
    augLoss  = loss_utils.aug_loss(table_logits.unsqueeze(0), aug_table_logits.unsqueeze(0))  
    randlaLoss = get_randla_loss(logits_aug, point_labels, class_weights) #how do I calculate randla_loss here? Do I calculate loss between logits_augmented and logits true?


    augLoss.backward(retain_graph=True)
    randlaLoss.backward(retain_graph=True)


    optimizer_r.step()
    optimizer_a.step()
    start_step2 = time.time()
    
    print(f"Epoch {epoch+1}, Augmentor loss:{round(augLoss.item(), 4)}, RandlaNet loss:{round(randlaLoss.item(), 4)}, Epoch_duration: {round((start_step2-start_step1),2)} s")



Epoch 1, Augmentor loss:4.2178, RandlaNet loss:0.4771, Epoch_duration: 1.83 s
Epoch 2, Augmentor loss:3.7484, RandlaNet loss:0.4749, Epoch_duration: 1.81 s
Epoch 3, Augmentor loss:3.2046, RandlaNet loss:0.4707, Epoch_duration: 1.8 s
Epoch 4, Augmentor loss:2.5104, RandlaNet loss:0.4671, Epoch_duration: 1.78 s
Epoch 5, Augmentor loss:2.18, RandlaNet loss:0.4634, Epoch_duration: 1.79 s
Epoch 6, Augmentor loss:2.1126, RandlaNet loss:0.4604, Epoch_duration: 1.8 s
Epoch 7, Augmentor loss:2.0858, RandlaNet loss:0.4559, Epoch_duration: 1.81 s
Epoch 8, Augmentor loss:2.0582, RandlaNet loss:0.4518, Epoch_duration: 1.81 s
Epoch 9, Augmentor loss:2.0287, RandlaNet loss:0.4477, Epoch_duration: 1.78 s
Epoch 10, Augmentor loss:2.0063, RandlaNet loss:0.4445, Epoch_duration: 1.77 s
Epoch 11, Augmentor loss:1.9966, RandlaNet loss:0.4416, Epoch_duration: 1.82 s
Epoch 12, Augmentor loss:1.994, RandlaNet loss:0.439, Epoch_duration: 1.77 s
Epoch 13, Augmentor loss:1.9845, RandlaNet loss:0.4363, Epoch_durat

In [17]:
torch.save(randla.state_dict(), '../../processed/saved_models/PA_baseline_scene0040_00')