# Prototype: PRS-Net

In [2]:
import open3d as o3d
import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Variable

import sys
sys.path.insert(0, '../impl/utils/')
import voxel_processing as vp

# device will determine whether to run the training on GPU or CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


device(type='cuda')

In [3]:
# Prepare dataset
vp.preprocess_files('../data/obj_data/', '../data/voxel_data/')

Processing: ../data/obj_data/shapenet/1021a0914a7207aff927ed529ad90a11/models/model_normalized.obj -> ../data/voxel_data/voxel_grid_0
1 file(s) have been processed.


In [4]:
# Load datasets
data = vp.read_dataset_from_path('../data/voxel_data/')
data = vp.prepare_dataset(data)
data

Reading: ../data/voxel_data/voxel_grid_1.obj


Reading: ../data/voxel_data/voxel_grid_0.omap
Reading: ../data/voxel_data/voxel_grid_1.offsetvec
Reading: ../data/voxel_data/voxel_grid_0.gridpoints
Reading: ../data/voxel_data/voxel_grid_0.obj
Reading: ../data/voxel_data/voxel_grid_1.omap
Reading: ../data/voxel_data/voxel_grid_0.offsetvec
Reading: ../data/voxel_data/voxel_grid_1.gridpoints
2 dataset(s) have been processed. 
2 dataset(s) have been prepared. 


[(TriangleMesh with 6130 points and 8448 triangles.,
  tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]],
  
           [[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]],
  
           [[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]],
  
           ...,
  
           [[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ...

Implementing PRS-Net

In [5]:
class PRSNet_Encoder(nn.Module):
    def __init__(self, ) -> None:
        super().__init__()
        
        leaky_ReLU_slope = 0.2
        
        # 32^3x1
        self.conv_layer0 = nn.Conv3d(in_channels=1, out_channels=4, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        self.max_pool0 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
        self.leaky_relu0 = nn.LeakyReLU(negative_slope=leaky_ReLU_slope)
        
        # 16^3x4
        self.conv_layer1 = nn.Conv3d(in_channels=4, out_channels=8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        self.max_pool1 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
        self.leaky_relu1 = nn.LeakyReLU(negative_slope=leaky_ReLU_slope)
        
        # 8^3x8
        self.conv_layer2 = nn.Conv3d(in_channels=8, out_channels=16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        self.max_pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
        self.leaky_relu2 = nn.LeakyReLU(negative_slope=leaky_ReLU_slope)
        
        # 4^3x16
        self.conv_layer3 = nn.Conv3d(in_channels=16, out_channels=32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        self.max_pool3 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
        self.leaky_relu3 = nn.LeakyReLU(negative_slope=leaky_ReLU_slope)
        
        # 2^3x32
        self.conv_layer4 = nn.Conv3d(in_channels=32, out_channels=64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        self.max_pool4 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
        self.leaky_relu4 = nn.LeakyReLU(negative_slope=leaky_ReLU_slope)
        # 1^3x64
    
    def forward(self, voxels):
        out = self.conv_layer0(voxels)
        out = self.max_pool0(out)
        out = self.leaky_relu0(out)
        
        out = self.conv_layer1(out)
        out = self.max_pool1(out)
        out = self.leaky_relu1(out)
        
        out = self.conv_layer2(out)
        out = self.max_pool2(out)
        out = self.leaky_relu2(out)
        
        out = self.conv_layer3(out)
        out = self.max_pool3(out)
        out = self.leaky_relu3(out)
        
        out = self.conv_layer4(out)
        out = self.max_pool4(out)
        out = self.leaky_relu4(out)
        
        return out
        

In [6]:
class PRSNet_Plane_Predictor(nn.Module):
    def __init__(self, ) -> None:
        super().__init__()
        
        # implicit symmetry planes: 4 features, aX + bY + cZ + d = 0
        self.fc0 = nn.Linear(64, 32)
        self.relu0 = nn.LeakyReLU()
        self.fc1 = nn.Linear(32, 16)
        self.relu1 = nn.LeakyReLU()
        self.fc2 = nn.Linear(16, 4)
    
    def set_initial_bias(self, feature):
        self.fc2.bias.data = feature.clone()
        
    def forward(self, features):
        out = self.fc0(features)
        out = self.relu0(out)
        out = self.fc1(out)
        out = self.relu1(out)
        out = self.fc2(out)
        
        return out

In [7]:
class PRSNet_Quaternion_Predictor(nn.Module):
    def __init__(self, ) -> None:
        super().__init__()
        
        # quaterion rotation: 4 features, a + bi + cj + dk
        self.fc0 = nn.Linear(64, 32)
        self.relu0 = nn.LeakyReLU()
        self.fc1 = nn.Linear(32, 16)
        self.relu1 = nn.LeakyReLU()
        self.fc2 = nn.Linear(16, 4)
    
    def set_initial_bias(self, feature):
        self.fc2.bias.data = feature.clone()
        
    def forward(self, features):
        out = self.fc0(features)
        out = self.relu0(out)
        out = self.fc1(out)
        out = self.relu1(out)
        out = self.fc2(out)
        
        return out

In [10]:
class PRSNet(nn.Module):
    def __init__(self, ) -> None:
        super(PRSNet, self).__init__()
        
        self.encoder = PRSNet_Encoder()
        
        self.plane_predictor0 = PRSNet_Plane_Predictor()
        self.plane_predictor1 = PRSNet_Plane_Predictor()
        self.plane_predictor2 = PRSNet_Plane_Predictor()
        
        self.plane_predictor0.set_initial_bias(torch.tensor([1., 0., 0., 0.]))
        self.plane_predictor1.set_initial_bias(torch.tensor([0., 1., 0., 0.]))
        self.plane_predictor2.set_initial_bias(torch.tensor([0., 0., 1., 0.]))
        
        self.quaternion_predictor0 = PRSNet_Quaternion_Predictor()
        self.quaternion_predictor1 = PRSNet_Quaternion_Predictor()
        self.quaternion_predictor2 = PRSNet_Quaternion_Predictor()
        
        cos_theta = torch.cos(torch.tensor(torch.pi / 2))
        sin_theta = torch.sin(torch.tensor(torch.pi / 2))
        
        self.quaternion_predictor0.set_initial_bias(torch.tensor([cos_theta, sin_theta, 0., 0.]))
        self.quaternion_predictor1.set_initial_bias(torch.tensor([cos_theta, 0., sin_theta, 0.]))
        self.quaternion_predictor2.set_initial_bias(torch.tensor([cos_theta, 0., 0., sin_theta]))
        
    def forward(self, batch_voxels):
        if batch_voxels.dim == 4:
            batch_voxels = batch_voxels.unsqueeze(0)
        out0 = self.encoder(batch_voxels)
        
        out0 = out0.reshape(-1, 64)
        M = out0.shape[0]
        
        plane0 = self.plane_predictor0(out0)
        plane1 = self.plane_predictor1(out0)
        plane2 = self.plane_predictor2(out0)
        
        plane0 = plane0.reshape(M, 4)
        plane1 = plane0.reshape(M, 4)
        plane2 = plane0.reshape(M, 4)
        
        quat0 = self.quaternion_predictor0(out0)
        quat1 = self.quaternion_predictor1(out0)
        quat2 = self.quaternion_predictor2(out0)
        
        quat0 = quat0.reshape(M, 4)
        quat1 = quat1.reshape(M, 4)
        quat2 = quat2.reshape(M, 4)
        
        return torch.stack([plane0, plane1, plane2], dim=1), torch.stack([quat0, quat1, quat2], dim=1)

In [11]:
batch_voxels = data[0][1]
# test_layer = nn.Conv3d(in_channels=1, out_channels=1, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
prsnet_test = PRSNet()
prsnet_result = prsnet_test.forward(batch_voxels.to(torch.float32))
print(prsnet_result[0].shape, prsnet_result[1].shape)
prsnet_result[0]

torch.Size([1, 3, 4]) torch.Size([1, 3, 4])


tensor([[[ 1.0158, -0.0934,  0.0044, -0.0354],
         [ 1.0158, -0.0934,  0.0044, -0.0354],
         [ 1.0158, -0.0934,  0.0044, -0.0354]]], grad_fn=<StackBackward0>)

In [12]:
class PRSNet_Symm_Dist_Loss(nn.Module):
    '''
    PRSNet Symmetry Distance Loss
    '''
    def __init__(self, ) -> None:
        super().__init__()

    def compute_batch_std_grid_indices(self, batch_query_points: torch.Tensor):
        '''
        `batch_query_points` should be of shape `(M, N, 3)`.
        
        Return a tensor of shape `(M, 3, N)`, where
        - `M` is the number of samples inside the batch,
        - `N` is the number of queries inside a single sample.
        '''
        M = batch_query_points.shape[0]
        N = batch_query_points.shape[1]
        
        device = batch_query_points.device
        
        tmp0 = batch_query_points.transpose(1, 2)
        xs = tmp0[:, 0].reshape((1, -1)).contiguous()
        ys = tmp0[:, 1].reshape((1, -1)).contiguous()
        zs = tmp0[:, 2].reshape((1, -1)).contiguous()
        
        x = (xs * 16).floor().to(torch.int).clamp(-16, 15) + 16
        y = (ys * 16).floor().to(torch.int).clamp(-16, 15) + 16
        z = (zs * 16).floor().to(torch.int).clamp(-16, 15) + 16
        
        x = x.reshape((M, 1, N))
        y = y.reshape((M, 1, N))
        z = z.reshape((M, 1, N))
        
        result0 = torch.arange(0, M, dtype=x.dtype, device=device).reshape(M, 1, 1).repeat([1, 1, N])
        result1 = torch.cat([x, y, z], dim=1)
        
        return torch.cat([*torch.cat([result0, result1], dim=1)], dim=1)

    def compute_batch_dist_sum(self, batch_grid_points, batch_query_points):
        '''
        Compute, for each sample, the sum of 'shortest distances' between points 
        and the closest point on the mesh's surface in their corresponding grids.
        
        Return summed_distance of shape (M, 1), where
        - `M` is batch size.
        '''
        M = batch_grid_points.shape[0]
        m, x, y, z = self.compute_batch_std_grid_indices(batch_query_points)
        
        g = batch_grid_points[m, x, y, z].reshape(-1, 3)
        q = batch_query_points.reshape(-1, 3)
        batch_displacements = g - q
        vector_norm = torch.linalg.vector_norm(batch_displacements, dim=1)
        result = torch.sum(vector_norm, dim=0)
        
        return result

    def apply_planar_transform(self, batch_plane, batch_sample_points):
        '''
        Compute, for each sample, sample points after planar reflective transformation.
        
        The formula is given by `q' = q - 2 <q - r, n> n`, where:
        - q is the target point, q' is the point after symmetric transformation,
        - r is the orthogonal displacement vector of the plane, and
        - n is the normalized normal vector of the plane.
        
        Return a tensor of shape `(M, N, D, 3)`
        '''
        M = batch_sample_points.shape[0]
        N = batch_sample_points.shape[1]
        D = batch_plane.shape[1]
        
        q = batch_sample_points
        
        n_norm = torch.norm(batch_plane[:, :, 0:3], dim=2).reshape(M, D, 1)
        n = (batch_plane[:, :, 0:3] / n_norm).transpose(1, 2)
        d = batch_plane[:, :, 3].reshape(M, 1, D).repeat([1, N, 1])
        coeff = (torch.einsum('bij,bjk->bik', q, n) + d) * 2
        
        coeff1 = coeff.reshape(M, N, D, 1).repeat([1, 1, 1, 3])
        n0 = n.transpose(1, 2).reshape(M, 1, D, 3).repeat([1, N, 1, 1])
        tmp = coeff1 * n0
        
        q0 = q.reshape(M, N, 1, 3).repeat([1, 1, D, 1])
        result = q0 - tmp
        
        return result
    
    def batch_quat_normalize(self, batch_quaternions):
        M = batch_quaternions.shape[0]
        D = batch_quaternions.shape[1]
        
        q_norm = torch.norm(batch_quaternions, dim=2).reshape(M, D, 1)
        q = batch_quaternions / q_norm
        return q
    
    def batch_multiply_quaternion(self, r, s):
        M = r.shape[0]
        N = r.shape[1]
        D = r.shape[2]
        
        result_r = torch.zeros((M, N, D), requires_grad=True)
        result_i = torch.zeros((M, N, D), requires_grad=True)
        result_j = torch.zeros((M, N, D), requires_grad=True)
        result_k = torch.zeros((M, N, D), requires_grad=True)
        result_r = r[:, :, :, 0] * s[:, :, :, 0] - r[:, :, :, 1] * s[:, :, :, 1] \
                    - r[:, :, :, 2] * s[:, :, :, 2] - r[:, :, :, 3] * s[:, :, :, 3]
        result_i = r[:, :, :, 1] * s[:, :, :, 0] + r[:, :, :, 0] * s[:, :, :, 1] \
                    + r[:, :, :, 2] * s[:, :, :, 3] - r[:, :, :, 3] * s[:, :, :, 2]
        result_j = r[:, :, :, 2] * s[:, :, :, 0] + r[:, :, :, 0] * s[:, :, :, 2] \
                    + r[:, :, :, 3] * s[:, :, :, 1] - r[:, :, :, 1] * s[:, :, :, 3]
        result_k = r[:, :, :, 3] * s[:, :, :, 0] + r[:, :, :, 0] * s[:, :, :, 3] \
                    + r[:, :, :, 1] * s[:, :, :, 2] - r[:, :, :, 2] * s[:, :, :, 1]
                    
        result_r = result_r.reshape(M, N, D, 1)
        result_i = result_i.reshape(M, N, D, 1)
        result_j = result_j.reshape(M, N, D, 1)
        result_k = result_k.reshape(M, N, D, 1)
        return torch.cat([result_r, result_i, result_j, result_k], dim=3)
    
    def apply_quaternion_rotation(self, batch_quaternions, batch_sample_points):
        '''
        Compute, for each sample, sample points after rotation using quaternions.
        '''
        M = batch_sample_points.shape[0]
        N = batch_sample_points.shape[1]
        D = batch_quaternions.shape[1]
        
        device = batch_quaternions.device
        
        p = torch.cat([torch.zeros((M, N, 1), device=device), batch_sample_points], dim=2).reshape(M, N, 1, 4).repeat([1, 1, D, 1])
        # normalized the quaternion first
        q = self.batch_quat_normalize(batch_quaternions)
        # prepare for quaternion multiplication
        q0 = q.reshape(M, 1, D, 4).repeat([1, N, 1, 1])
        q0p_im = -q0[:, :, :, 1:4].clone()
        q0p_re = q0[:, :, :, 0].clone().reshape(M, N, D, 1)
        q0p = torch.cat([q0p_re, q0p_im], dim=3)
        
        tmp0 = self.batch_multiply_quaternion(p, q0)
        tmp0 = self.batch_multiply_quaternion(q0p, tmp0)
        
        return tmp0[:, :, :, 0:3].clone()
    
    def forward(self, batch_planar_features, batch_quat_features, batch_grid_points, batch_sample_points):
        if batch_planar_features.dim == 3:
            batch_planar_features = batch_planar_features.unsqueeze(0)
        if batch_quat_features.dim == 3:
            batch_quat_features = batch_quat_features.unsqueeze(0)
        if batch_grid_points.dim == 4:
            batch_grid_points = batch_grid_points.unsqueeze(0)
        if batch_sample_points.dim == 2:
            batch_sample_points = batch_sample_points.unsqueeze(0)

        M = batch_sample_points.shape[0]
        N = batch_sample_points.shape[1]
        D0 = batch_planar_features.shape[1]
        D1 = batch_quat_features.shape[1]
        
        p_trans_points = self.apply_planar_transform(batch_planar_features, batch_sample_points)
        p_losses = self.compute_batch_dist_sum(batch_grid_points, p_trans_points.reshape(M, -1, 3))
        planar_loss = torch.sum(p_losses)
        
        q_trans_points = self.apply_quaternion_rotation(batch_quat_features, batch_sample_points)
        q_losses = self.compute_batch_dist_sum(batch_grid_points, q_trans_points.reshape(M, -1, 3))
        quat_loss = torch.sum(q_losses)
        
        return planar_loss + quat_loss

In [13]:
# Test `compute_batch_std_grid_indices`
batch_sample_points = data[0][4].unsqueeze(0)
test_layer = PRSNet_Symm_Dist_Loss()
test_layer.compute_batch_std_grid_indices(batch_sample_points)

tensor([[ 0,  0,  0,  ...,  0,  0,  0],
        [16, 16, 16,  ..., 16, 15, 15],
        [ 0,  0,  1,  ...,  3,  3,  3],
        [ 4,  4,  4,  ...,  4,  4,  5]], dtype=torch.int32)

In [14]:
# Test `apply_planar_transform`
batch_planar_features = prsnet_result[0]
test_layer = PRSNet_Symm_Dist_Loss()
test_layer.apply_planar_transform(batch_planar_features, batch_sample_points.to(torch.float32))

tensor([[[[-0.1188, -0.9300, -0.7450],
          [-0.1188, -0.9300, -0.7450],
          [-0.1188, -0.9300, -0.7450]],

         [[-0.1057, -0.9451, -0.7378],
          [-0.1057, -0.9451, -0.7378],
          [-0.1057, -0.9451, -0.7378]],

         [[-0.1068, -0.8868, -0.7299],
          [-0.1068, -0.8868, -0.7299],
          [-0.1068, -0.8868, -0.7299]],

         ...,

         [[-0.0724, -0.7467, -0.6894],
          [-0.0724, -0.7467, -0.6894],
          [-0.0724, -0.7467, -0.6894]],

         [[-0.0519, -0.7669, -0.7142],
          [-0.0519, -0.7669, -0.7142],
          [-0.0519, -0.7669, -0.7142]],

         [[-0.0201, -0.7601, -0.6681],
          [-0.0201, -0.7601, -0.6681],
          [-0.0201, -0.7601, -0.6681]]]], grad_fn=<SubBackward0>)

In [15]:
# Test `apply_quat_rotation`
batch_quat_features = prsnet_result[1]
test_layer = PRSNet_Symm_Dist_Loss()
test_layer.apply_quaternion_rotation(batch_quat_features, batch_sample_points.to(torch.float32))

tensor([[[[ 0.0000e+00,  4.0872e-02,  8.7294e-01],
          [ 1.7753e-09,  3.9647e-02, -8.7185e-01],
          [ 0.0000e+00,  5.7773e-02,  8.4882e-01]],

         [[ 9.3132e-10,  2.5109e-02,  8.8592e-01],
          [ 4.6566e-10,  5.5217e-02, -8.8451e-01],
          [-3.7253e-09,  7.2834e-02,  8.6195e-01]],

         [[-9.3132e-10,  3.6529e-02,  8.2976e-01],
          [ 7.5670e-10,  4.1649e-02, -8.2862e-01],
          [ 0.0000e+00,  5.9755e-02,  8.0619e-01]],

         ...,

         [[ 0.0000e+00,  2.7048e-02,  6.8983e-01],
          [-1.2806e-09,  4.4181e-02, -6.8865e-01],
          [ 0.0000e+00,  6.2569e-02,  6.6785e-01]],

         [[ 0.0000e+00,  3.9269e-03,  7.0363e-01],
          [ 1.1642e-09,  6.9543e-02, -7.0190e-01],
          [-3.7253e-09,  8.8688e-02,  6.8086e-01]],

         [[ 0.0000e+00, -2.7551e-02,  6.9526e-01],
          [-2.3283e-10,  9.7285e-02, -6.9296e-01],
          [ 0.0000e+00,  1.1461e-01,  6.7381e-01]]]], grad_fn=<CloneBackward0>)

In [16]:
# Test `compute_batch_dist_sum`
batch_grid_points = data[0][2].unsqueeze(0)
print(batch_sample_points.shape)
print()

test_layer = PRSNet_Symm_Dist_Loss()
test_layer.compute_batch_dist_sum(batch_grid_points.to(torch.float32), batch_sample_points.to(torch.float32))

torch.Size([1, 1000, 3])



tensor(24.6989)

In [17]:
# Test `Symm_Dist_Loss` computation
prsnet_symm_loss_layer = PRSNet_Symm_Dist_Loss()
prsnet_symm_loss_result = prsnet_symm_loss_layer(batch_planar_features, batch_quat_features, batch_grid_points.to(torch.float32), batch_sample_points.to(torch.float32))
prsnet_symm_loss_result

tensor(1797.0505, grad_fn=<AddBackward0>)

In [18]:
class PRSNet_Reg_Loss(nn.Module):
    '''
    PRS Regularization Loss
    '''
    def __init__(self, ) -> None:
        super().__init__()
    
    def forward(self, batch_planar_features, batch_quat_features):
        # TODO: Compute Regularization Loss
        '''
        `batch_planar_features` should have a shape of `(M, D, 4)`
        '''
        M = batch_planar_features.shape[0]
        D1 = batch_planar_features.shape[1]
        D2 = batch_quat_features.shape[1]
        device = batch_planar_features.device
        
        m1_norm = torch.norm(batch_planar_features[:, :, 0:3], dim=2).reshape(M, D1, 1)
        m1 = (batch_planar_features[:, :, 0:3] / m1_norm)
        
        m1_m1t = torch.einsum('bij, bjk->bik', m1, m1.transpose(1, 2).contiguous())
        m1_id_mat = torch.eye(D1, device=device, requires_grad=True).reshape(1, D1, D1).repeat([M, 1, 1])
        A = m1_m1t - m1_id_mat
        
        
        m2_norm = torch.norm(batch_planar_features[:, :, 1:4], dim=2).reshape(M, D1, 1)
        m2 = (batch_planar_features[:, :, 1:4] / m2_norm)
        
        m2_m2t = torch.einsum('bij, bjk->bik', m2, m2.transpose(1, 2).contiguous())
        m2_id_mat = torch.eye(D2, device=device, requires_grad=True).reshape(1, D1, D1).repeat([M, 1, 1])
        B = m2_m2t - m2_id_mat
        
        loss = torch.norm(A, dim=(1, 2)) + torch.norm(B, dim=(1, 2))
        
        return torch.sum(loss)

In [19]:
prsnet_reg_loss_layer = PRSNet_Reg_Loss()
prsnet_reg_loss_result = prsnet_reg_loss_layer(batch_planar_features, batch_quat_features)
prsnet_reg_loss_result

tensor(4.8990, grad_fn=<SumBackward0>)

In [36]:
class PRSNet_Loss(nn.Module):
    def __init__(self, w_r=25) -> None:
        super().__init__()
        self.symmetry_loss = PRSNet_Symm_Dist_Loss()
        self.reg_loss = PRSNet_Reg_Loss()
        self.w_r = w_r
    
    def forward(self, batch_planar_features, batch_quat_features, batch_grid_points, batch_sample_points):
        symm_loss = self.symmetry_loss(batch_planar_features, batch_quat_features, batch_grid_points, batch_sample_points)
        reg_loss = self.reg_loss(batch_planar_features, batch_quat_features)
        return symm_loss + self.w_r * reg_loss

In [21]:
prsnet_loss_layer = PRSNet_Loss(w_r=25)
prsnet_loss_result = prsnet_loss_layer(batch_planar_features, batch_quat_features, batch_grid_points.to(torch.float32), batch_sample_points.to(torch.float32))
prsnet_loss_result

tensor(1919.5250, grad_fn=<AddBackward0>)

## Overall Test

In [22]:
# device will determine whether to run the training on GPU or CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [24]:
from torch.utils.data import DataLoader, Dataset
class CustomVoxelDataset(Dataset):
    def __init__(self, dataset_lst):
        self.dataset_lst = dataset_lst
    
    def __len__(self):
        return len(self.dataset_lst)
    
    def __getitem__(self, index):
        return self.dataset_lst[index]

train_dataset = CustomVoxelDataset(data)
train_dataset

<__main__.CustomVoxelDataset at 0x7eff9c1b9a50>

In [39]:
dtype = torch.float32
batch_size = 32
num_classes = 10
learning_rate = 0.01
weight_decay = 0.005
num_epochs = 30
w_r = 25

In [25]:
def collate_data_list(raw_dataset):
    mesh_lst = []
    omap_lst = []
    grid_points_lst = []
    offset_vector_lst = []
    sample_points_lst = []
    
    for entry in raw_dataset:
        mesh_lst.append(entry[0])
        omap_lst.append(entry[1].reshape(1, 1, 32, 32, 32))
        grid_points_lst.append(entry[2].reshape(1, 32, 32, 32, 3))
        offset_vector_lst.append(entry[3].reshape(1, 1, 3))
        sample_points_lst.append(entry[4].reshape(1, -1, 3))
        
    # batch_mesh = torch.concat(mesh_lst, dim=0)
    batch_omap = torch.concat(omap_lst, dim=0)
    batch_grid_points = torch.concat(grid_points_lst, dim=0)
    batch_offset_vector = torch.concat(offset_vector_lst, dim=0)
    batch_sample_points = torch.concat(sample_points_lst, dim=0)
    return mesh_lst, batch_omap, batch_grid_points, batch_offset_vector, batch_sample_points

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_data_list)
train_loader

<torch.utils.data.dataloader.DataLoader at 0x7eff9c1b8c10>

In [34]:
import os
if not os.path.exists('../checkpoints/'):
    os.mkdir('../checkpoints')

In [40]:
model = PRSNet()
model.to(device=device, dtype=dtype)

criterion = PRSNet_Loss(w_r)
PATH = '../checkpoints/prsnet.pt'

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

for epoch in range(num_epochs):
    for i, (mesh, omap, grid_points, offset_vector, sample_points) in enumerate(train_loader):
        omap = omap.to(device=device, dtype=dtype)
        grid_points = grid_points.to(device, dtype=dtype)
        offset_vector = offset_vector.to(device, dtype=dtype)
        sample_points = sample_points.to(device, dtype=dtype)
        
        optimizer.zero_grad()
        p_features, q_features = model(omap)
        loss = criterion(p_features, q_features, grid_points, sample_points)
        loss.backward()
        optimizer.step()
    
    print(f'Epoch [{epoch + 1}/{num_epochs}, Loss: {"{:.4f}".format(loss.item())}]')
    torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'w_r': w_r
                }, PATH)

Epoch [1/30, Loss: 3662.2817]
Epoch [2/30, Loss: 3401.3774]
Epoch [3/30, Loss: 2826.7827]
Epoch [4/30, Loss: 2072.0676]
Epoch [5/30, Loss: 2603.5547]
Epoch [6/30, Loss: 1386.7670]
Epoch [7/30, Loss: 1386.2344]
Epoch [8/30, Loss: 869.6987]
Epoch [9/30, Loss: 817.9903]
Epoch [10/30, Loss: 935.2198]
Epoch [11/30, Loss: 758.9705]
Epoch [12/30, Loss: 820.9611]
Epoch [13/30, Loss: 762.6750]
Epoch [14/30, Loss: 761.2150]
Epoch [15/30, Loss: 764.3408]
Epoch [16/30, Loss: 722.7332]
Epoch [17/30, Loss: 711.6592]
Epoch [18/30, Loss: 707.3466]
Epoch [19/30, Loss: 706.5849]
Epoch [20/30, Loss: 714.3914]
Epoch [21/30, Loss: 719.8852]
Epoch [22/30, Loss: 720.5227]
Epoch [23/30, Loss: 721.1401]
Epoch [24/30, Loss: 721.2773]
Epoch [25/30, Loss: 716.5581]
Epoch [26/30, Loss: 714.8130]
Epoch [27/30, Loss: 707.8404]
Epoch [28/30, Loss: 702.4996]
Epoch [29/30, Loss: 701.0557]
Epoch [30/30, Loss: 698.0577]


In [31]:
model

PRSNet(
  (encoder): PRSNet_Encoder(
    (conv_layer0): Conv3d(1, 4, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (max_pool0): MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
    (leaky_relu0): LeakyReLU(negative_slope=0.2)
    (conv_layer1): Conv3d(4, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (max_pool1): MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
    (leaky_relu1): LeakyReLU(negative_slope=0.2)
    (conv_layer2): Conv3d(8, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (max_pool2): MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
    (leaky_relu2): LeakyReLU(negative_slope=0.2)
    (conv_layer3): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (max_pool3): MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
    (leaky_re

##### Load the Model

In [41]:
model1 = PRSNet()
criterion = PRSNet_Loss()
optimizer = torch.optim.Adam(model1.parameters(), lr=learning_rate, weight_decay=weight_decay)

checkpoint = torch.load(PATH)
model1.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
opoch = checkpoint['epoch']
w_r = checkpoint['w_r']

model1.train()
model1

PRSNet(
  (encoder): PRSNet_Encoder(
    (conv_layer0): Conv3d(1, 4, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (max_pool0): MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
    (leaky_relu0): LeakyReLU(negative_slope=0.2)
    (conv_layer1): Conv3d(4, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (max_pool1): MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
    (leaky_relu1): LeakyReLU(negative_slope=0.2)
    (conv_layer2): Conv3d(8, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (max_pool2): MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
    (leaky_relu2): LeakyReLU(negative_slope=0.2)
    (conv_layer3): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (max_pool3): MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
    (leaky_re