In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
import time
import numpy as np
import torch
import torch.optim as optim

from imps.data.scannet import ScanNetScene, CLASS_NAMES
from imps.sqn.model import Randla
from imps.sqn.data_utils import prepare_input

from imps.point_augment.Common import loss_utils
from imps.point_augment.Augmentor.augmentor import Augmentor
from imps.point_augment.Classifier.classifier import RandlaClassifier

# Scene to load. Set the IMPS_SCENE_DIR environment variable to point the
# notebook at another location; without it the original path is used.
SCENE_DIR = os.environ.get('IMPS_SCENE_DIR', '/app/mnt/scans/scene0040_00')

# Number of points requested from the scene when sampling.
N_POINTS = 150_000
# Device handed to prepare_input and Randla.
DEVICE = 'cuda'

# Label ids whose class weight is set to zero when the loss weights are built.
IGNORED_LABELS = [0]

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Open the ScanNet scene stored under SCENE_DIR.
scene = ScanNetScene(SCENE_DIR)

# Request N_POINTS samples from the scene. The three results are used below as
# point coordinates, point colors and point labels
# (presumably index-aligned, one row per sampled point — TODO confirm).
surface_points, surface_colors, point_labels = scene.create_points_colors_labels_from_pc(N_POINTS)

In [3]:
# Group the sampled surface points by label so that a single object class
# (the table, label 7, further down) can be extracted from the scene.
# key: label id; value: list of the surface points carrying that label.

# Convert once, outside the loop: the conversion does not depend on the label.
surface_points_arr = np.array(surface_points)

label_point_dict = {}
for c in range(len(CLASS_NAMES) + 1):
    # Indices of every point whose label equals c.
    res_list = np.where(point_labels == c)
    # Stored as a list of points, the layout the following cells expect.
    label_point_dict[c] = list(surface_points_arr[res_list])

In [4]:
# Stack the points stored under label 7 into a single array, then wrap them as
# a float tensor with a leading batch dimension of 1.
# NOTE(review): the variable name suggests label 7 is the table class — confirm
# against CLASS_NAMES.
label_point_dict[7] = np.array(label_point_dict[7])
xyz_table = torch.FloatTensor(label_point_dict[7]).unsqueeze(0)
print(xyz_table.shape)   

torch.Size([1, 1416, 3])


In [5]:
# Map every label id to the indices of the points carrying that label. The indices make it possible to address the points (or per-point network outputs) of one particular object.

label_index_dict = {} 
for c in range(len(CLASS_NAMES)+1):
    res_list = np.where(point_labels == c) # np.where with a single argument returns a tuple of index arrays
    label_index_dict[c] = res_list

In [6]:
# Class weights for the RandLA loss: every class is weighted by its share of
# the sampled points, after the counts of the ignored labels are zeroed.

class_counts = np.array([np.sum(point_labels == c) for c in range(len(CLASS_NAMES))])

for ign in IGNORED_LABELS:
    class_counts[ign] = 0
class_weights = class_counts / class_counts.sum()


In [7]:
# excluding out point_mask for now.

def get_randla_loss(logits, labels, class_weights):
    """Class-weighted cross-entropy averaged over every point.

    Args:
        logits: tensor of class scores; flattened to
            (-1, len(class_weights)) before the loss is evaluated.
        labels: integer tensor of target class ids; flattened to 1-D.
        class_weights: numpy array holding one weight per class.

    Returns:
        Scalar tensor: the sum of the weighted per-point losses divided by the
        total number of points. Points whose class weight is zero add nothing
        to the sum but still count in the denominator. No point mask is
        applied.
    """
    weights = torch.from_numpy(class_weights).float().to(logits.device)
    n_classes = len(weights)

    per_point_loss = torch.nn.functional.cross_entropy(
        logits.reshape(-1, n_classes),
        labels.reshape(-1),
        weight=weights,
        reduction='none',
    )

    return per_point_loss.sum() / per_point_loss.shape[0]


In [8]:

def add_aug_instance_to_pc(aug_instance, surface_points, index_list):
    """Return a copy of the scene in which one object's points are replaced.

    Args:
        aug_instance: tensor holding the augmented coordinates. After
            `squeeze(0)` it must be 2-D (dims, n); it is transposed to (n, dims)
            rows before being written into the scene.
        surface_points: scene coordinates, a numpy array (or anything
            `np.array` accepts) or a torch tensor. It is NOT modified.
        index_list: indices of the rows to overwrite; must select n rows.

    Returns:
        A new array (or tensor, if a tensor was passed) equal to
        `surface_points` except that the rows in `index_list` hold the
        augmented coordinates.
    """
    # Detach from the autograd graph and move to host memory before the write.
    aug_instance = aug_instance.squeeze(0).transpose(1, 0).detach().cpu()

    # Work on a copy: writing into the caller's array would overwrite the
    # original object coordinates for every later use of the scene.
    if torch.is_tensor(surface_points):
        augmented_points = surface_points.clone()
    else:
        augmented_points = np.array(surface_points, copy=True)
        aug_instance = aug_instance.numpy()

    augmented_points[index_list] = aug_instance
    return augmented_points

In [9]:
def collect_label_pools(labels, input_pools, N_CLASS):
    """Carry point labels through the pooling layers by majority vote.

    Args:
        labels: 1-D integer array with one label per input point (a numpy
            array in this notebook).
        input_pools: iterable of index tensors, one per layer. After
            `squeeze()` each must be 2-D (n_pts, k); row i lists the k points
            of the previous level that are pooled into point i.
        N_CLASS: number of classes; every label must be a valid index into it.

    Returns:
        A list with one integer array per layer holding, for each pooled
        point, the most frequent label among its k source points. On a tie
        `argmax` picks the lowest label id.
    """
    pooled_per_layer = []
    current_labels = labels

    for pool in input_pools:
        pool_idx = pool.squeeze().cpu().numpy()
        n_pts, k = pool_idx.shape

        # One-hot encode the label of every source point, then sum over the k
        # sources of each pooled point to obtain its vote counts per class.
        source_labels = current_labels[pool_idx.reshape(-1)]
        votes = np.eye(N_CLASS)[source_labels].reshape(n_pts, k, N_CLASS).sum(axis=1)

        current_labels = votes.argmax(axis=-1)
        pooled_per_layer.append(current_labels)

    return pooled_per_layer
 

In [10]:
def collect_instance_global_features(label_pools, features_encoder_list, target_label):
    """Build one global feature vector for the points of a single class.

    For every encoder layer, the features of the points labelled
    `target_label` are max-pooled over the point axis; the per-layer vectors
    are then concatenated.

    Args:
        label_pools: per-layer label arrays, index-aligned with
            `features_encoder_list`.
        features_encoder_list: per-layer feature tensors; each must reduce to
            (channels, n_points) after `squeeze(0).squeeze(-1)`.
        target_label: label id of the object to summarise.

    Returns:
        1-D tensor whose length is the sum of the layers' channel counts.

    Raises:
        ValueError: if `features_encoder_list` is empty, or if some layer has
            no point labelled `target_label` (a max over zero points is
            undefined).
    """
    instance_features_list = []

    for index, feature in enumerate(features_encoder_list):
        # (1, channels, n_points, 1) -> (n_points, channels)
        feature = feature.squeeze(0).squeeze(-1).permute(1, 0)
        index_list = np.where(label_pools[index] == target_label)
        instance_features = feature[index_list]
        if instance_features.shape[0] == 0:
            raise ValueError(
                f"no point with label {target_label} in encoder layer {index}"
            )
        global_instance_features, _ = torch.max(instance_features, 0)
        instance_features_list.append(global_instance_features)

    if not instance_features_list:
        raise ValueError("features_encoder_list is empty")

    # Concatenate once, after all layers have been reduced.
    return torch.cat(instance_features_list, dim=0)


In [11]:
# memory used : 717MiB 

# Wrap the sampled arrays as tensors with a leading batch dimension of 1.
# Colors and labels are moved to DEVICE here; xyz is not moved in this cell.
# NOTE(review): `point_labels` is rebound from the array used by the earlier
# cells to a tensor on DEVICE, so those cells will likely not re-run after this
# one without reloading the scene — consider a separate name.
features = torch.FloatTensor(surface_colors).unsqueeze(0).to(DEVICE)
xyz = torch.FloatTensor(surface_points).unsqueeze(0)
point_labels = torch.LongTensor(point_labels).unsqueeze(0).to(DEVICE)

In [12]:
# Build the network inputs from the scene coordinates: points, neighbor indices and pooling indices (presumably one entry per layer — TODO confirm), plus feat_shape, which is passed to RandlaClassifier further down.

# memory used: 22MiB



input_points, input_neighbors, input_pools, feat_shape = prepare_input(xyz, k=16, num_layers=3, encoder_dims = [8,32,64], sub_sampling_ratio=4, 
                                                           device=DEVICE)



In [13]:
# RandLA network with the same encoder_dims that were passed to prepare_input.
# d_feature=3 presumably matches the 3 color channels fed as `features` — TODO confirm.
randla = Randla(d_feature=3, d_in=8, encoder_dims=[8, 32, 64], device=DEVICE, num_class=len(CLASS_NAMES), interpolator='keops')

In [14]:
# Classifier applied to the global feature vector of one object. It is placed
# on DEVICE (the device given to prepare_input and Randla) rather than on a
# hard-coded CUDA device.
classifier = RandlaClassifier(feat_shape=feat_shape, N_CLASS=len(CLASS_NAMES)).to(DEVICE)

In [15]:
# Channel-first table coordinates for the augmentor: (1, 3, n_table_points).
# The tensor is rebuilt from label_point_dict on every run, so executing this
# cell twice does not transpose an already-transposed tensor back.
xyz_table = (
    torch.FloatTensor(np.array(label_point_dict[7]))
    .unsqueeze(0)
    .to(DEVICE)
    .transpose(2, 1)
    .contiguous()
)

# Flat Python list of the table's point indices. reshape(-1) (instead of
# squeeze) still yields a list when only a single point carries the label, and
# re-running the cell on the already-converted list gives the same list.
label_index_dict[7] = np.asarray(label_index_dict[7]).reshape(-1).tolist()

In [16]:
# Create the augmentor and one Adam optimizer for each trained network.

# memory used: 22MiB

dim = 3
augmentor = Augmentor().cuda()

# Both optimizers run at lr=1e-3; betas and eps are spelled out for the augmentor.
optimizer_r = optim.Adam(randla.parameters(), lr=1e-3)  # RandLA-Net
optimizer_a = optim.Adam(augmentor.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8)  # augmentor


In [17]:
TARGET_LABEL = 7   # label id of the object that is augmented (the table)
N_EPOCHS = 100
NOISE_DIM = 1024   # width of the noise vector handed to the augmentor
NOISE_SCALE = 0.02

randla.train()
augmentor = augmentor.train()

# The labels and the pooling indices of the original scene do not change
# between epochs, so the label pools of the original scene are computed once.
n_class = len(CLASS_NAMES)
labels_np = point_labels.squeeze().cpu().numpy()
label_pools = collect_label_pools(labels_np, input_pools, N_CLASS=n_class)

for epoch in range(0, N_EPOCHS):
    start_step1 = time.time()
    optimizer_a.zero_grad()
    optimizer_r.zero_grad()

    # --- Original point cloud: table feature vector -> classifier logits.
    features_encoder_list = randla.encoder(features, input_points, input_neighbors, input_pools)
    global_features = collect_instance_global_features(label_pools, features_encoder_list, target_label=TARGET_LABEL)
    table_logits = classifier(global_features.unsqueeze(0))

    # --- Augment the table with a fresh noise vector.
    noise = NOISE_SCALE * torch.randn(1, NOISE_DIM).to(DEVICE)
    aug_table = augmentor(xyz_table, noise)

    # Put the augmented table back into the scene.
    # NOTE(review): add_aug_instance_to_pc detaches the augmented coordinates
    # and a new tensor is built from the result, so as far as this cell shows
    # no gradient of either loss reaches the augmentor — confirm whether
    # optimizer_a is expected to update anything.
    augmented_scene = add_aug_instance_to_pc(aug_table, surface_points, label_index_dict[TARGET_LABEL])
    augmented_scene = torch.FloatTensor(augmented_scene).unsqueeze(0)

    # --- Augmented point cloud. The geometry changes every epoch, so the
    # network inputs have to be rebuilt from the augmented scene each time.
    # NOTE(review): unlike the first prepare_input call, encoder_dims is not
    # passed here — confirm the default matches [8, 32, 64].
    input_points_aug, input_neighbors_aug, input_pools_aug, feat_shape = prepare_input(
        augmented_scene, k=16, num_layers=3, sub_sampling_ratio=4, device=DEVICE)

    features_encoder_list_aug = randla.encoder(features, input_points_aug, input_neighbors_aug, input_pools_aug)
    label_pools_aug = collect_label_pools(labels_np, input_pools_aug, N_CLASS=n_class)
    global_features_aug = collect_instance_global_features(label_pools_aug, features_encoder_list_aug, target_label=TARGET_LABEL)
    aug_table_logits = classifier(global_features_aug.unsqueeze(0))

    logits_aug = randla.decoder(features_encoder_list_aug, input_points_aug)

    # Third argument presumably is the ground-truth class of the object — TODO confirm.
    augLoss = loss_utils.aug_loss(table_logits, aug_table_logits, TARGET_LABEL)
    # TODO(open question from the author): the segmentation loss compares the
    # augmented-scene logits with the ground-truth labels; decide whether it
    # should instead be compared with the logits of the original scene.
    randlaLoss = get_randla_loss(logits_aug, point_labels, class_weights)

    # One backward pass over the summed losses accumulates the same gradients
    # as two separate passes, without keeping the graph alive via retain_graph.
    (augLoss + randlaLoss).backward()

    optimizer_r.step()
    optimizer_a.step()
    start_step2 = time.time()

    print(f"Epoch {epoch+1}, Augmentor loss:{round(augLoss.item(), 4)}, RandlaNet loss:{round(randlaLoss.item(), 4)}, Epoch_duration: {round((start_step2-start_step1),2)} s")



Epoch 1, Augmentor loss:3.5269, RandlaNet loss:0.4624, Epoch_duration: 2.07 s
Epoch 2, Augmentor loss:3.3411, RandlaNet loss:0.432, Epoch_duration: 1.94 s
Epoch 3, Augmentor loss:3.2831, RandlaNet loss:0.4127, Epoch_duration: 1.95 s
Epoch 4, Augmentor loss:3.3064, RandlaNet loss:0.3985, Epoch_duration: 1.96 s
Epoch 5, Augmentor loss:3.2021, RandlaNet loss:0.387, Epoch_duration: 1.93 s
Epoch 6, Augmentor loss:3.1728, RandlaNet loss:0.3764, Epoch_duration: 1.93 s
Epoch 7, Augmentor loss:3.216, RandlaNet loss:0.3669, Epoch_duration: 1.96 s
Epoch 8, Augmentor loss:3.0239, RandlaNet loss:0.3592, Epoch_duration: 1.9 s
Epoch 9, Augmentor loss:2.9211, RandlaNet loss:0.3531, Epoch_duration: 1.97 s
Epoch 10, Augmentor loss:2.8224, RandlaNet loss:0.3473, Epoch_duration: 1.89 s
Epoch 11, Augmentor loss:2.9257, RandlaNet loss:0.3425, Epoch_duration: 1.94 s
Epoch 12, Augmentor loss:2.6091, RandlaNet loss:0.3376, Epoch_duration: 1.9 s
Epoch 13, Augmentor loss:2.8638, RandlaNet loss:0.3323, Epoch_dura

In [18]:
# Save the RandLA weights (state_dict) only; nothing is saved here for the
# augmentor or the classifier. The path is relative to the notebook's working
# directory.
torch.save(randla.state_dict(), '../../processed/saved_models/PA_scene0040_00')

In [None]:
### Rotating, scaling and translating a table individually and putting it back to the original point cloud ###
# NOTE(review): this cell is shown unexecuted. `surface_points_table_` and
# `PointcloudRotatebyAngle` are neither defined nor imported in the visible
# part of the notebook, so it cannot run after a kernel restart as it stands.


# Add a leading batch axis to the table points.
surface_points_table = surface_points_table_[np.newaxis, :, :]
point_cloud_rotate = PointcloudRotatebyAngle(surface_points_table)

print(surface_points_table.shape)
surface_points_table = torch.from_numpy(surface_points_table)

In [None]:
# Augmentation: rotation
# NOTE(review): relies on `point_cloud_rotate` and `surface_points_table` from
# the preceding unexecuted cell; the meaning of the second argument (1) is not
# visible here — TODO document it.

point_cloud_table_rotated = point_cloud_rotate(surface_points_table, 1)


In [None]:
# Augmentation: scale and translate the rotated table.
# NOTE(review): `PointcloudScaleAndTranslate` is not imported in the visible
# part of the notebook.

point_cloud_scale_and_translate = PointcloudScaleAndTranslate()
point_cloud_table_st = point_cloud_scale_and_translate(point_cloud_table_rotated)

In [None]:
# Write the augmented table back into the scene point cloud.
# NOTE(review): `surface_points1` is not defined in the visible part of the
# notebook, and `res_list` holds whatever the last executed loop left in it,
# which is not necessarily the table indices — verify before running.

point_cloud_table_st = point_cloud_table_st.cpu()
surface_points1[res_list] = point_cloud_table_st

print(surface_points1.shape)