## Creating a Regressor for the plane class

In [1]:
# imports
from dataset import SHREC2022Primitives
from networks import MinkowskiFCNN
import transforms as t
from losses import Losses
import MinkowskiEngine as ME


import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import os

In [2]:
# configuration & hyperparameters
# Dataset root. Set the SHREC2022_DATASET environment variable to run the
# notebook on another machine; the fallback is the original location.
path = os.environ.get(
    "SHREC2022_DATASET",
    "/home/ioannis/Desktop/programming/data/SHREC/SHREC2022/dataset",
)
batch_size = 128
lr = 1e-3
num_epochs = 20
# The (currently disabled) evaluation code tests `(epoch+1) % eval_step == 0`;
# this value is far larger than num_epochs, so that condition never holds.
eval_step = 1_000_000_000

In [3]:
def minkowski_collate(list_data):
    """Collate a list of sample dicts into one batch dict.

    Every sample must provide 'x', 'y', 'norm_factor' and 'shift'.
    'mean' and 'inverse_rotation' are optional: when no sample carries the
    key, the corresponding batch entry is None.

    Args:
        list_data: list of per-sample dicts as produced by the dataset.

    Returns:
        dict with keys "coordinates", "features", "labels", "means",
        "norm_factors", "shifts" and "inv_rotations".

    Raises:
        ValueError: if an optional key is present in some samples but not in
            all of them (the stacked tensor would no longer line up with the
            batch).
    """
    # The points in 'x' are passed both as coordinates and as features.
    coordinates, features, labels = ME.utils.sparse_collate(
        [d['x'] for d in list_data],
        [d['x'] for d in list_data],
        [d['y'].unsqueeze(0) for d in list_data],
        dtype=torch.float32
    )

    def _stack_optional(key):
        # Stack `key` over the batch, or return None when no sample has it.
        present = [d[key] for d in list_data if key in d]
        if not present:
            return None
        if len(present) != len(list_data):
            raise ValueError(
                f"'{key}' is present in only {len(present)} of "
                f"{len(list_data)} samples"
            )
        return torch.stack(present)

    return {
        "coordinates"   : coordinates,
        "features"      : features,
        "labels"        : labels,
        "means"         : _stack_optional('mean'),
        "norm_factors"  : torch.stack([d['norm_factor'] for d in list_data]),
        "shifts"        : torch.stack([d['shift'] for d in list_data]),
        "inv_rotations" : _stack_optional('inverse_rotation')
    }

In [4]:
# Dataset and dataloader
# Training pipeline: the three RandomRotate entries use the same first
# argument (180) and differ only in the second one (0, 1, 2).
train_transforms = [
    t.Translate(),
    t.SphereNormalization(),
    t.Initialization(),
    *[t.RandomRotate(180, axis) for axis in (0, 1, 2)],
    t.GaussianNoise(),
    t.GetMean(),
]

# Validation pipeline: only the first two transforms of the training pipeline.
valid_transforms = [t.Translate(), t.SphereNormalization()]

# Training part of the "plane" category (valid=False).
t_dataset = SHREC2022Primitives(
    path,
    category="plane",
    transform=train_transforms,
    train=True,
    valid=False,
    valid_split=0.2,
)
train_loader = DataLoader(
    t_dataset,
    batch_size=batch_size,
    collate_fn=minkowski_collate,
    num_workers=8,
    shuffle=True,
)

# Validation part of the same category (valid=True), iterated without shuffling.
v_dataset = SHREC2022Primitives(
    path,
    category="plane",
    transform=valid_transforms,
    train=True,
    valid=True,
    valid_split=0.2,
)
valid_loader = DataLoader(
    v_dataset,
    batch_size=batch_size,
    collate_fn=minkowski_collate,
    num_workers=8,
    shuffle=False,
)


Specified split already exists. Using the existing one.
Specified split already exists. Using the existing one.


In [5]:
# network and device
# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# 3 input channels and 3 output channels; the model is moved to `device`.
net = MinkowskiFCNN(in_channel=3, out_channel=3).to(device)

Using device: cuda:0


In [6]:
def create_input_batch(batch, device="cuda", quantization_size=0.05):
    """Build the MinkowskiEngine TensorField for a collated batch.

    Column 0 of batch["coordinates"] is left untouched; all remaining columns
    are divided by `quantization_size`. The scaling is done on a copy, so
    `batch` is not modified and the function can be called repeatedly on the
    same batch with the same result.

    Args:
        batch: dict with "coordinates" and "features" tensors.
        device: device handed to ME.TensorField.
        quantization_size: positive divisor applied to the coordinate columns.

    Returns:
        ME.TensorField built from the scaled coordinates and the features.

    Raises:
        ValueError: if `quantization_size` is not positive.
    """
    if quantization_size <= 0:
        raise ValueError(
            f"quantization_size must be positive, got {quantization_size}"
        )

    # Work on a clone: scaling batch["coordinates"] in place would rescale
    # the data again on every further call with the same batch.
    coordinates = batch["coordinates"].clone()
    coordinates[:, 1:] = coordinates[:, 1:] / quantization_size
    return ME.TensorField(
        coordinates=coordinates,
        features=batch["features"],
        device=device
    )

In [7]:
def transform_outputs(output, scale, shift, rotation_mat=None):
    """Undo the input transforms on a batch of predictions.

    Args:
        output: B x 6 tensor; columns 0-2 are the normal, columns 3-5 the
            vertex.
        scale: per-sample factor, shape B or B x 1.
        shift: B x 3 tensor subtracted from the scaled vertex.
        rotation_mat: optional B x 3 x 3 matrices applied to both the normal
            and the vertex; skipped when None.

    Returns:
        (normal, vertex), each B x 3.
    """
    # reshape(-1, 1) gives B x 1 for both accepted scale shapes, so the
    # product below stays B x 3 instead of broadcasting to B x B x 3.
    scale = scale.to(output.device).reshape(-1, 1)
    shift = shift.to(output.device)

    normal = output[:, 0:3]
    vertex = output[:, 3:]
    if rotation_mat is not None:
        rotation_mat = rotation_mat.to(output.device)
        # the rotation is applied to the normal and to the vertex;
        # scaling and shifting below only affect the vertex
        normal = (rotation_mat @ normal.unsqueeze(-1)).squeeze(-1)
        vertex = (rotation_mat @ vertex.unsqueeze(-1)).squeeze(-1)
    vertex = vertex * scale - shift

    return normal, vertex


def transform_plane_outputs(plane_pred, trans):
    """Dict-based front end for transform_outputs.

    Args:
        plane_pred: B x 6 tensor (normal in columns 0-2, vertex in 3-5).
        trans: dict with "norm_factors" (scale), "shifts" and optionally
            "inv_rotations" (None or missing means no rotation is applied).

    Returns:
        (normal, vertex), each B x 3.
    """
    return transform_outputs(
        plane_pred,
        trans["norm_factors"],
        trans["shifts"],
        trans.get("inv_rotations"),
    )

In [8]:
# optimizer and losses
# Loss container; the training loop calls its AxisToAxisLoss and
# PointToPlaneLoss methods.
criterion = Losses()
# Adam over all parameters of the network, with the configured learning rate.
optimizer = torch.optim.Adam(params=net.parameters(), lr=lr)

In [9]:
# Training loop

for epoch in range(num_epochs):
    # Put the network in training mode explicitly, so the loop also behaves
    # correctly when an earlier cell left it in eval mode.
    net.train()

    # running sums of the per-batch mean losses for this epoch
    m_loss = 0
    m_vertex_loss = 0
    m_normal_loss = 0

    for batch in tqdm(train_loader):

        optimizer.zero_grad()

        # label columns 1-3 are used as the target normal, 4-6 as the target vertex
        labels = batch["labels"].to(device)
        actual_normal = labels[:, 1:4]
        actual_vertex = labels[:, 4:7]

        # NOTE: the global `batch_size` hyperparameter is deliberately not
        # reassigned here; doing so would change the value seen by any cell
        # that is (re-)run afterwards.

        minknet_input = create_input_batch(
            batch,
            device=device,
            quantization_size=0.05
        )

        # The network output is extended with the collated per-sample means;
        # transform_plane_outputs reads the first three columns as the normal
        # and the remaining ones as the vertex.
        pred = net(minknet_input)
        pred = torch.cat([pred, batch['means'].to(device)], dim=-1)

        pred_normal, pred_vertex = transform_plane_outputs(
            pred,
            {"norm_factors" : batch["norm_factors"],
             "shifts"       : batch["shifts"],
             "inv_rotations": batch["inv_rotations"]}
        )

        # reduce each loss to its batch mean once and reuse it below
        normal_loss = criterion.AxisToAxisLoss(pred_normal, actual_normal).mean()
        vertex_loss = criterion.PointToPlaneLoss(pred_vertex, actual_vertex, actual_normal).mean()

        loss = normal_loss + vertex_loss
        loss.backward()
        optimizer.step()

        m_normal_loss += normal_loss.item()
        m_vertex_loss += vertex_loss.item()
        m_loss += loss.item()

    # average over the number of batches in the epoch
    m_loss /= len(train_loader)
    m_normal_loss /= len(train_loader)
    m_vertex_loss /= len(train_loader)

    print(f" Epoch: {epoch} | Training: loss = {m_loss}")
    print(f" Normal loss: {m_normal_loss} | Vertex_loss: {m_vertex_loss}")


print("Done!")

100%|██████████| 58/58 [00:14<00:00,  4.04it/s]


 Epoch: 0 | Training: loss = 0.1597797566189848
 Normal loss: 0.13015334455874458 | Vertex_loss: 0.0296264108398865


 72%|███████▏  | 42/58 [00:11<00:04,  3.53it/s]


KeyboardInterrupt: 

In [None]:
# Disabled validation pass. The code is wrapped in a bare triple-quoted string,
# so nothing in it is executed; the cell only runs the final `pass`.
# Its indentation and its use of `epoch` show that it belongs inside the
# `for epoch` loop of the training cell.
# NOTE(review): before re-enabling, check the following against the training
# loop: the losses are accumulated with `.item()` without a prior `.mean()`,
# `transform_outputs` is called on the raw network output (no means are
# concatenated), and `batch_size` is reassigned, which would overwrite the
# global hyperparameter. TODO confirm.
'''
    if (epoch+1) % eval_step == 0:
        m_loss = 0
        m_normal_loss = 0
        m_vertex_loss = 0
        net.eval()
        with torch.no_grad():
            for batch in tqdm(valid_loader):

                labels = batch["labels"].to(device)
                actual_normal = labels[:,1:4]
                actual_vertex = labels[:,4:7]
                batch_size = labels.shape[0]
                
                net_in = create_input_batch(
                    batch, 
                    device=device, 
                    quantization_size=0.05
                )

                pred = net(net_in)
              
                pred_normal, pred_vertex = transform_outputs(pred, 
                                                     batch["norm_factors"], 
                                                     batch["shifts"], 
                                                     batch["inv_rotations"])

                normal_loss = criterion.AxisToAxisLoss(pred_normal, actual_normal)
                vertex_loss = criterion.PointToPlaneLoss(pred_vertex, actual_vertex, actual_normal)

                loss = normal_loss + vertex_loss
          
                m_normal_loss += normal_loss.item()
                m_vertex_loss += vertex_loss.item()
                m_loss += loss.item()
                
            m_loss /= len(valid_loader)
            m_normal_loss /= len(valid_loader)
            m_vertex_loss /= len(valid_loader)

            print(f" --------->  Evaluation: loss = {m_loss}")
            print(f" Normal loss: {m_normal_loss} | Vertex_loss: {m_vertex_loss}")

        # setting network back to training mode
        net.train()
'''
pass

In [None]:
# Scratch setup for the rotation checks in the following cells.
# NOTE: this rebinds `train_transforms` (here without t.GetMean()) and
# `t_dataset` (here without a `transform` argument), replacing the objects
# created for training earlier in the notebook.
train_transforms = [
    t.Translate(),
    t.SphereNormalization(),
    t.Initialization(),
    *[t.RandomRotate(180, axis) for axis in (0, 1, 2)],
    t.GaussianNoise(),
]

t_dataset = SHREC2022Primitives(
    path,
    category="plane",
    train=True,
    valid=False,
    valid_split=0.2,
)


In [None]:
import open3d as o3d

def makeO3Dpc(points):
    """Create an Open3D point cloud from a set of points.

    Args:
        points: a torch tensor (on any device, with or without gradient
            tracking), any other object exposing `.numpy()`, or an array
            that o3d.utility.Vector3dVector accepts directly.

    Returns:
        o3d.geometry.PointCloud whose `points` are set from the input.
    """
    if hasattr(points, "detach"):
        # A tensor must be detached and moved to host memory before it can
        # be converted with .numpy().
        points = points.detach().cpu().numpy()
    elif hasattr(points, "numpy"):
        points = points.numpy()

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    return pcd

def visualize_pointcloud(pcd):
    """Show a point cloud in the Open3D viewer.

    `pcd` may already be an o3d.geometry.PointCloud; anything else is first
    converted with makeO3Dpc.
    """
    cloud = pcd if isinstance(pcd, o3d.geometry.PointCloud) else makeO3Dpc(pcd)
    o3d.visualization.draw_geometries([cloud])

In [None]:
# First sample of `t_dataset`, kept as the reference that the transformed
# sample is compared against in the following cells.
# NOTE(review): presumably untransformed, since `t_dataset` was re-created
# without a `transform` argument — confirm against the dataset class.
ini_sample = t_dataset[0]
# Optional visual check of the reference points:
#visualize_pointcloud(ini_sample['x'])

In [None]:
from copy import copy, deepcopy

# Deep copy: a shallow copy shares its tensors with `ini_sample`, so a
# transform that works in place would also alter the reference and make the
# comparison below meaningless.
sample = deepcopy(ini_sample)

# Apply entries 2..5 of `train_transforms` in order (t.Initialization and the
# three t.RandomRotate transforms).
for transform in train_transforms[2:6]:
    sample = transform(sample)

# Signed sum of the point-wise difference to the reference sample.
print((sample['x']-ini_sample['x']).sum())

#visualize_pointcloud(sample['x'])

In [None]:
print(sample['inverse_rotation'])
# Apply the stored inverse rotation to every point. The result gets its own
# name instead of overwriting sample['x'], so re-running this cell does not
# apply the inverse rotation a second time.
restored_x = (sample['inverse_rotation'].unsqueeze(0) @ sample['x'].unsqueeze(-1)).squeeze(-1)
print(restored_x.shape)
#visualize_pointcloud(restored_x.cpu().detach())
print((restored_x - ini_sample['x']).sum())
# A signed sum can cancel out; the largest absolute deviation cannot.
print((restored_x - ini_sample['x']).abs().max())