# Prototype: PRS-Net

In [1]:
import open3d as o3d
import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Variable

import sys
sys.path.insert(0, '../impl/utils/')
import voxel_processing as vp

# Training hyperparameters, kept together so they are easy to find and tune.
batch_size, num_epochs = 64, 20
num_classes = 10
learning_rate = 1e-3

# Prefer the GPU when CUDA is usable; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


device(type='cuda')

In [2]:
# Load datasets: read everything under the voxel data directory and hand the
# result straight to vp.prepare_dataset; `data` holds the prepared datasets.
data = vp.prepare_dataset(vp.read_dataset_from_path('../data/voxel_data/'))

Reading: ../data/voxel_data/voxel_grid_1.obj
Reading: ../data/voxel_data/voxel_grid_0.omap
Reading: ../data/voxel_data/voxel_grid_1.offsetvec
Reading: ../data/voxel_data/voxel_grid_0.gridpoints
Reading: ../data/voxel_data/voxel_grid_0.obj
Reading: ../data/voxel_data/voxel_grid_1.omap
Reading: ../data/voxel_data/voxel_grid_0.offsetvec
Reading: ../data/voxel_data/voxel_grid_1.gridpoints
2 dataset(s) have been processed. 
2 dataset(s) have been prepared. 


Implementing PRS-Net

In [6]:
class PRS_Encoder(nn.Module):
    '''
    Convolutional encoder built from five identical stages.

    Each stage is a 3x3x3 convolution (stride 1, padding 1), then a 2x2x2
    max-pool with stride 2, then a LeakyReLU with negative slope 0.2. The
    channel count goes 1 -> 4 -> 8 -> 16 -> 32 -> 64 and every pooling halves
    the spatial extent, so a 32^3 single-channel grid comes out as 1^3x64.
    '''
    def __init__(self, ) -> None:
        super().__init__()

        slope = 0.2
        cube3, cube2, unit = (3, 3, 3), (2, 2, 2), (1, 1, 1)

        def conv(c_in, c_out):
            # Convolution that keeps the spatial size and changes the channel count.
            return nn.Conv3d(in_channels=c_in, out_channels=c_out, kernel_size=cube3, stride=unit, padding=unit)

        def halve():
            # Pooling that halves every spatial dimension.
            return nn.MaxPool3d(kernel_size=cube2, stride=cube2)

        def activation():
            return nn.LeakyReLU(negative_slope=slope)

        # Sub-modules are created and registered stage by stage, in the order
        # convolution, pooling, activation.

        # 32^3x1 -> 16^3x4
        self.conv_layer0, self.max_pool0, self.leaky_relu0 = conv(1, 4), halve(), activation()
        # 16^3x4 -> 8^3x8
        self.conv_layer1, self.max_pool1, self.leaky_relu1 = conv(4, 8), halve(), activation()
        # 8^3x8 -> 4^3x16
        self.conv_layer2, self.max_pool2, self.leaky_relu2 = conv(8, 16), halve(), activation()
        # 4^3x16 -> 2^3x32
        self.conv_layer3, self.max_pool3, self.leaky_relu3 = conv(16, 32), halve(), activation()
        # 2^3x32 -> 1^3x64
        self.conv_layer4, self.max_pool4, self.leaky_relu4 = conv(32, 64), halve(), activation()

    def forward(self, voxels):
        '''
        Push `voxels` through the five stages in order and return the result.
        '''
        stages = (
            (self.conv_layer0, self.max_pool0, self.leaky_relu0),
            (self.conv_layer1, self.max_pool1, self.leaky_relu1),
            (self.conv_layer2, self.max_pool2, self.leaky_relu2),
            (self.conv_layer3, self.max_pool3, self.leaky_relu3),
            (self.conv_layer4, self.max_pool4, self.leaky_relu4),
        )

        out = voxels
        for convolve, pool, activate in stages:
            out = activate(pool(convolve(out)))

        return out
        

In [7]:
class PRS_Plane_Predictor(nn.Module):
    '''
    Fully connected head that outputs one implicit symmetry plane.

    A 64-feature input goes through 64 -> 32 -> 16 -> 4 linear layers, with a
    default-slope LeakyReLU after each of the first two. The 4 outputs are the
    coefficients of the plane aX + bY + cZ + d = 0.
    '''
    def __init__(self, ) -> None:
        super().__init__()

        # Two hidden layers, each paired with its own activation module.
        self.fc0, self.relu0 = nn.Linear(64, 32), nn.LeakyReLU()
        self.fc1, self.relu1 = nn.Linear(32, 16), nn.LeakyReLU()
        # Output layer: the 4 plane coefficients, with no activation applied.
        self.fc2 = nn.Linear(16, 4)

    def forward(self, features):
        '''
        Map `features` (last dimension 64) to the 4 plane coefficients.
        '''
        hidden = self.relu0(self.fc0(features))
        hidden = self.relu1(self.fc1(hidden))
        return self.fc2(hidden)

In [8]:
class PRS_Quaterion_Predictor(nn.Module):
    '''
    Fully connected head that outputs one rotation quaternion.

    A 64-feature input goes through 64 -> 32 -> 16 -> 4 linear layers, with a
    default-slope LeakyReLU after each of the first two. The 4 outputs are the
    components of the quaternion a + bi + cj + dk.
    '''
    def __init__(self, ) -> None:
        super().__init__()

        # Two hidden layers, each paired with its own activation module.
        self.fc0, self.relu0 = nn.Linear(64, 32), nn.LeakyReLU()
        self.fc1, self.relu1 = nn.Linear(32, 16), nn.LeakyReLU()
        # Output layer: the 4 quaternion components, with no activation applied.
        self.fc2 = nn.Linear(16, 4)

    def forward(self, features):
        '''
        Map `features` (last dimension 64) to the 4 quaternion components.
        '''
        hidden = self.relu0(self.fc0(features))
        hidden = self.relu1(self.fc1(hidden))
        return self.fc2(hidden)

In [11]:
class PRSNet(nn.Module):
    '''
    Full PRS-Net model: one shared encoder feeding six independent heads.

    The encoder output is reshaped to rows of 64 features; every row is then
    passed to three plane predictors and three quaternion predictors.
    '''
    def __init__(self, ) -> None:
        super().__init__()

        self.encoder = PRS_Encoder()

        # Three plane heads, then three quaternion heads; each has its own weights.
        self.plane_predictor0 = PRS_Plane_Predictor()
        self.plane_predictor1 = PRS_Plane_Predictor()
        self.plane_predictor2 = PRS_Plane_Predictor()

        self.quaterion_predictor0 = PRS_Quaterion_Predictor()
        self.quaterion_predictor1 = PRS_Quaterion_Predictor()
        self.quaterion_predictor2 = PRS_Quaterion_Predictor()

    def forward(self, batch_voxels):
        '''
        Encode `batch_voxels` and run all six heads on the encoded features.

        Returns a pair `([plane0, plane1, plane2], [quat0, quat1, quat2])`,
        where every entry is the 4-feature output of the corresponding head.
        '''
        # Collapse the encoder output into rows of 64 features for the linear heads.
        features = self.encoder(batch_voxels).reshape(-1, 64)

        plane_heads = (self.plane_predictor0, self.plane_predictor1, self.plane_predictor2)
        quat_heads = (self.quaterion_predictor0, self.quaterion_predictor1, self.quaterion_predictor2)

        planes = [head(features) for head in plane_heads]
        quats = [head(features) for head in quat_heads]

        return planes, quats

In [16]:
# Smoke test: run one forward pass of a freshly initialised PRSNet.
batch_voxels = data[1]
test_layer = PRSNet()
# Call the module itself rather than `.forward(...)`: nn.Module.__call__ is the
# supported entry point and also runs any registered hooks.
result = test_layer(batch_voxels)
# `result` is ([plane0, plane1, plane2], [quat0, quat1, quat2]); display the
# shape of the first plane prediction as a quick sanity check.
result[0][0].shape

torch.Size([2, 4])

In [None]:
def compute_single_dist_sum(grid_points, query_points):
    '''
    Return the summed 'shortest distance' of all query points.

    Each query point is paired with the entry of `grid_points` selected by its
    grid index (the closest point on the mesh's surface for that grid). The
    Euclidean lengths of the resulting displacements are added up into a
    single value.
    '''
    # NOTE(review): the layout returned by vp.compute_std_grid_indices is not
    # visible here. If the transposed result is a (3, N) index tensor,
    # `grid_points[...]` indexes only the first dimension - confirm that this
    # selects one grid entry per query point.
    cell_indices = vp.compute_std_grid_indices(query_points).T
    offsets = grid_points[cell_indices] - query_points
    return torch.linalg.vector_norm(offsets, dim=1).sum()

In [None]:
class PRSNet_Symm_Dist_Loss(nn.Module):
    '''
    PRSNet Symmetry Distance Loss

    Work in progress: the two helpers below are usable, while `forward` is
    still a placeholder that always returns 0.
    '''
    def __init__(self, ) -> None:
        super().__init__()

    @staticmethod
    def compute_dist_sum(batch_grid_points, batch_query_points):
        '''
        Compute, for each sample, the sum of 'shortest distances' between points 
        and the closest point on the mesh's surface in their corresponding grids

        Declared as a static method because it takes no `self`; this makes it
        callable both on the class and on an instance.
        '''
        # NOTE(review): the index layout returned by
        # vp.compute_batch_std_grid_indices is not visible here - confirm that
        # this indexing selects one grid entry per query point.
        batch_query_indices = vp.compute_batch_std_grid_indices(batch_query_points)
        batch_displacements = batch_grid_points[batch_query_indices] - batch_query_points
        vector_norm = torch.linalg.vector_norm(batch_displacements, dim=2)
        result = torch.sum(vector_norm, dim=1)
        return result

    @staticmethod
    def apply_planar_transform(sample_points, plane):
        '''
        Reflect the given sample points across the given plane(s).

        For a plane `aX + bY + cZ + d = 0` with normal `n = (a, b, c)`, the
        reflection of a point q is `q' = q - 2 * (<q, n> + d) / |n|^2 * n`.
        This is the same as `q' = q - 2 <q - r, n_hat> n_hat`, where n_hat is
        the unit normal and r is any point on the plane.

        Parameters:
        - sample_points: tensor of shape (..., N, 3), the points to reflect.
        - plane: tensor of shape (..., 4) holding (a, b, c, d); its leading
          dimensions must match (or broadcast with) those of `sample_points`.

        Returns a tensor with the shape of `sample_points`.
        '''
        normal = plane[..., 0:3]
        offset = plane[..., 3:4]

        # Normalise every plane by the length of its own normal. The offset d
        # has to be scaled by the same factor, otherwise the signed distance
        # below is wrong for non-unit normals.
        length = torch.linalg.vector_norm(normal, dim=-1, keepdim=True)
        n = (normal / length).unsqueeze(-2)       # (..., 1, 3)
        d = (offset / length).unsqueeze(-2)       # (..., 1, 1)

        q = sample_points
        # Signed distance of every point to its plane, shape (..., N, 1).
        signed_dist = torch.sum(q * n, dim=-1, keepdim=True) + d
        q_prime = q - 2. * signed_dist * n

        return q_prime

    def forward(self, batch_planar_features, batch_axial_features, batch_grid_points, batch_sample_points):
        '''
        Placeholder: the planar and axial terms are not implemented yet, so the
        inputs are ignored and the result is always 0.
        '''
        # TODO: compute both terms from the predicted planes / rotations.
        planar_loss = 0
        axial_loss = 0
        return planar_loss + axial_loss

In [None]:
class PRSNet_Reg_Loss(nn.Module):
    '''
    PRS Regularization Loss

    Not implemented yet: `forward` ignores its inputs and yields None.
    '''
    def __init__(self, ) -> None:
        super().__init__()

    def forward(self, features, grid_points):
        # TODO: implement the regularization term.
        return None

In [None]:
class PRSNet_Loss(nn.Module):
    '''
    Container for the two PRS-Net loss terms.

    Holds a symmetry-distance loss module, a regularization loss module and the
    value `w_reg` (presumably the weight of the regularization term - TODO
    confirm once `forward` is written). `forward` is not implemented yet and
    yields None.
    '''
    def __init__(self, w_reg) -> None:
        super().__init__()
        self.w_reg = w_reg
        self.symmetry_loss = PRSNet_Symm_Dist_Loss()
        self.reg_loss = PRSNet_Reg_Loss()

    def forward(self, features, grid_points, sample_points):
        # TODO: combine the symmetry and regularization terms.
        return None