# Installs

In [None]:
!pip install tensorboard
!pip install open3d

# Imports

In [None]:
import json
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import open3d as o3d
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter

%load_ext tensorboard

In [None]:
try:
    import pytorch3d
except ModuleNotFoundError:
    need_pytorch3d=True
if need_pytorch3d:
    if torch.__version__.startswith("1.7") and sys.platform.startswith("linux"):
        # We try to install PyTorch3D via a released wheel.
        version_str="".join([
            f"py3{sys.version_info.minor}_cu",
            torch.version.cuda.replace(".",""),
            f"_pyt{torch.__version__[0:5:2]}"
        ])
        !pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
    else:
        # We try to install PyTorch3D from source.
        !curl -LO https://github.com/NVIDIA/cub/archive/1.10.0.tar.gz
        !tar xzf 1.10.0.tar.gz
        os.environ["CUB_HOME"] = os.getcwd() + "/cub-1.10.0"
        !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'


In [None]:
from pytorch3d.transforms import so3_relative_angle, matrix_to_quaternion, quaternion_to_matrix

# Network Architecture

## Input Transform Layer

This layer corresponds to the input transform (T-Net) described in the PointNet architecture: it predicts a 3×3 matrix that is applied to the input point coordinates.

In [None]:
class TNet3(nn.Module):
    '''
    Input-transform T-Net: regresses one 3x3 matrix per point cloud.

    Input  : tensor of shape (batch_size, 3, n)
    Output : tensor of shape (batch_size, 3, 3), equal to the 3x3 identity
             plus the regressed offsets

    The BatchNorm1d layers are registered (so they are part of the
    state_dict) but forward() does not call them.
    '''
    def __init__(self):
        super().__init__()
        # per-point shared MLP: 3 -> 64 -> 128 -> 1024
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        # fully connected: 1024 -> 512 -> 256
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        # 9 outputs = flattened 3x3 matrix
        self.fc3 = nn.Linear(256, 9)
        # unused in forward; kept so the state_dict keys stay unchanged
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

    def forward(self, x):
        n_batch = x.size(0)
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
        # symmetric max-pool over the point dimension -> (n_batch, 1024)
        feat = x.max(dim=2, keepdim=True).values.view(-1, 1024)
        for fc in (self.fc1, self.fc2):
            feat = F.relu(fc(feat))
        offsets = self.fc3(feat)
        # bias the prediction towards the identity transform
        identity = torch.eye(3, dtype=offsets.dtype, device=offsets.device).flatten().unsqueeze(0).expand(n_batch, -1)
        return (offsets + identity).view(-1, 3, 3)

## Feature Transform Layer

This layer corresponds to the feature transform (T-Net) described in the PointNet architecture: it predicts a 64×64 matrix that is applied to the 64-dimensional per-point features.

In [None]:
class TNet64(nn.Module):
    '''
    Feature-transform T-Net: regresses one 64x64 matrix per point cloud.

    Input  : tensor of shape (batch_size, 64, n) -- 64-channel per-point features
    Output : tensor of shape (batch_size, 64, 64), equal to the 64x64 identity
             plus the regressed offsets

    The BatchNorm1d layers are registered (so they are part of the
    state_dict) but forward() does not call them.
    '''
    def __init__(self):
        super(TNet64,self).__init__()
        # shared MLP: 64 -> 64 -> 128 -> 1024
        self.conv1 = nn.Conv1d(64,64,1)
        self.conv2 = nn.Conv1d(64,128,1)
        self.conv3 = nn.Conv1d(128,1024,1)
        # fc layers: 1024 -> 512 -> 256
        self.fc1 = nn.Linear(1024,512)
        self.fc2 = nn.Linear(512,256)
        # output layer: flattened 64x64 matrix
        self.fc3 = nn.Linear(256,64*64)
        # batch norm layers (unused in forward; kept so state_dict keys stay unchanged)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

    def forward(self, x):
        batchsize = x.size()[0]
        # shared MLP
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        # max pool over points -> (batchsize, 1024)
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)
        # FC layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # output layer
        x = self.fc3(x)
        # 64x64 output matrix, biased towards the identity transform
        iden = torch.eye(64, dtype=x.dtype, device=x.device).view(1,64*64).repeat(batchsize, 1)
        x = x + iden
        x = x.view(-1,64,64)
        return x

## PointNet Features

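Implements the PointNet architecture for extracting a 1024-dimensional global feature. The main visible modification from the original PointNet is that batch normalization is not applied: the T-Net classes still define their `BatchNorm1d` layers but never call them. This modification was made in OBBNet for better results.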

In [77]:
class PointNet(nn.Module):
    '''
    PointNet global-feature extractor.

    Maps a point cloud of shape (batch_size, 3, n) to a 1024-dimensional
    global feature of shape (batch_size, 1024).

    has_input_tf   : if True, apply the 3x3 input transform predicted by TNet3
                     to the raw point coordinates
    has_feature_tf : if True, apply the 64x64 feature transform predicted by
                     TNet64 to the 64-channel per-point features
    '''
    def __init__(self,has_input_tf=True,has_feature_tf=True):
        super(PointNet,self).__init__()
        # flags to control the transformation layers
        self.has_input_tf = has_input_tf
        self.has_feature_tf = has_feature_tf
        # input transform (3x3)
        self.tnet3 = TNet3()
        # feature transform (64x64)
        self.tnet64 = TNet64()

        # shared MLP(64,128,1024)
        self.conv21 = nn.Conv1d(3,64,1)
        self.conv22 = nn.Conv1d(64,128,1)
        self.conv23 = nn.Conv1d(128,1024,1)

    def forward(self, x):
        '''
        x : tensor of shape (batch_size, 3, n)
        returns : tensor of shape (batch_size, 1024)
        '''
        # input transform
        if self.has_input_tf:
            input_tf = self.tnet3(x)
            x = torch.bmm(x.transpose(2,1), input_tf) # shape: (batch_size,n,3)
            x = x.transpose(2,1) # shape: (batch_size,3,n)

        # First shared-MLP layer lifts the points to 64 channels. It must run
        # before the feature transform, because TNet64 and its 64x64 matrix
        # operate on 64-channel features (the raw 3-channel input would crash).
        x = F.relu(self.conv21(x)) # shape: (batch_size,64,n)

        # feature transform
        if self.has_feature_tf:
            feature_tf = self.tnet64(x)
            x = torch.bmm(x.transpose(2,1), feature_tf) # shape: (batch_size,n,64)
            x = x.transpose(2,1) # shape: (batch_size,64,n)

        # remaining shared MLP(128,1024); no activation after the last layer
        x = F.relu(self.conv22(x))
        x = self.conv23(x)

        # maxpool over points
        x = torch.max(x, 2, keepdim=True)[0]

        # global feature
        return x.view(-1,1024)

## OBB Prediction

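Extends PointNet with two fully connected heads on the shared global feature. One head regresses the box center and extent (6 values); the other regresses the orientation as a quaternion (4 values).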

In [78]:
class OBBNet(nn.Module):
    '''
    Oriented-bounding-box regressor built on PointNet global features.

    forward(x) takes a point cloud tensor of shape (batch_size, 3, n) and
    returns a pair of tensors:
      - (batch_size, 6): box center (first 3 values) and extent (last 3)
      - (batch_size, 4): orientation quaternion (no normalisation applied)
    '''
    def __init__(self,has_input_tf, has_feature_tf):
        super(OBBNet, self).__init__()
        self.has_input_tf = has_input_tf
        self.has_feature_tf = has_feature_tf

        # shared backbone producing a 1024-d global feature
        self.pt_fts = PointNet(self.has_input_tf, self.has_feature_tf)

        # head 1, center + extent: 1024 -> 512 -> 256 -> 6
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 6)
        # head 2, orientation: 1024 -> 512 -> 256 -> 4
        self.fc4 = nn.Linear(1024, 512)
        self.fc5 = nn.Linear(512, 256)
        self.fc6 = nn.Linear(256, 4)

    @staticmethod
    def _head(feat, hidden_a, hidden_b, out):
        '''Two ReLU-activated hidden layers followed by a linear output layer.'''
        return out(F.relu(hidden_b(F.relu(hidden_a(feat)))))

    def forward(self, x):
        features = self.pt_fts(x)
        box = self._head(features, self.fc1, self.fc2, self.fc3)
        quat = self._head(features, self.fc4, self.fc5, self.fc6)
        return box, quat

# Loading Data

In [79]:
# custom Dataset class
class OBBDataset(Dataset):
    '''
    Point clouds paired with oriented-bounding-box (OBB) targets.

    Each item is a tuple (points, obb_params):
      points     : float32 numpy array of the cloud's points
                   (presumably shape (N, 3) -- confirm)
      obb_params : float32 tensor = center (3) + extent (3) + quaternion (4)

    NOTE(review): the JSON file only supplies the list of filenames. The
    target box is recomputed from the point cloud with open3d's
    get_oriented_bounding_box() instead of being read from the JSON labels.
    '''
    def __init__(self,jsonfilename,rootdir):
        '''
        jsonfilename : JSON file whose keys are point-cloud filenames
        rootdir : directory containing the point-cloud files
        '''
        with open(jsonfilename) as label_file:
            self.obb_json = json.load(label_file)
        self.filenames = [*self.obb_json]
        self.rootdir = rootdir

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        # samplers may pass the index as a tensor
        idx = idx.tolist() if torch.is_tensor(idx) else idx

        name = self.filenames[idx]
        cloud = o3d.io.read_point_cloud(os.path.join(self.rootdir, name))
        points = np.asarray(cloud.points).astype(np.float32)

        # target box computed from the cloud itself; rotation stored as a quaternion
        box = cloud.get_oriented_bounding_box()
        center = torch.tensor(np.array(box.center, dtype=np.float32))
        extent = torch.tensor(np.array(box.extent, dtype=np.float32))
        rotation = torch.tensor(np.array(box.R, dtype=np.float32))
        quat = matrix_to_quaternion(rotation)

        return points, torch.cat((center, extent, quat))

Load training and test data

In [80]:
# Dataset location (Google Drive mount on Colab)
DATA_DIR = '/content/drive/MyDrive/COMP0119/Datasets'
train_labels = os.path.join(DATA_DIR, 'shapenet_train_labels.json')
test_labels = os.path.join(DATA_DIR, 'shapenet_test_labels.json')
rootdir = os.path.join(DATA_DIR, 'shapenet5k')

train_set = OBBDataset(jsonfilename=train_labels, rootdir=rootdir)
test_set = OBBDataset(jsonfilename=test_labels, rootdir=rootdir)

batch_size = 32

# Shuffle the training set so every epoch sees the samples in a different
# order; the test set keeps the file order from the JSON keys.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size)

# Training Model

In [81]:
# Prefer the first CUDA GPU; fall back to the CPU when none is available.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [82]:
# OBBNet with both T-Net transforms disabled, placed on the training device
net = OBBNet(has_input_tf=False, has_feature_tf=False).to(device)
# squared error summed (not averaged) over the batch
criterion = nn.MSELoss(reduction='sum')

In [83]:
# Adam at lr 1e-3; StepLR halves the learning rate every 30 scheduler steps
# (the training loop steps it once per epoch).
optimizer = optim.Adam(net.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

In [84]:
def qloss(t1,t2):
    '''
    Sign-invariant quaternion loss.

    For each pair of quaternions (q1, q2) taken along the last dimension,
    computes d = 1 - <q1, q2>^2. d is 0 when q2 = q1 or q2 = -q1 (the same
    rotation). For unit quaternions it equals sin^2(theta/2), where theta is
    the relative rotation angle. That is a monotonic function of the geodesic
    distance, not the geodesic distance itself.

    t1, t2 : tensors whose last dimension has size 4, e.g. (batch_size, 4)
             or a single quaternion of shape (4,); inputs are not normalised
    returns : 0-dim tensor, the sum of d^2 over all pairs
    '''
    d = 1 - (t1 * t2).sum(dim=-1) ** 2
    # sum of squares directly instead of norm() followed by **2
    return (d ** 2).sum()

In [85]:
def blue(x):
    '''Return string x wrapped in ANSI escape codes so it prints in bright blue.'''
    return ''.join(('\033[94m', x, '\033[0m'))

In [None]:
num_epochs = 100
lossRecord = []  # mean training loss per batch, one entry per epoch
for epoch in range(num_epochs):
    net.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        inputs = inputs.transpose(2, 1)  # channels-first, as the Conv1d layers expect
        optimizer.zero_grad()
        output1, output2 = net(inputs)
        # labels[:, :6] = center + extent, labels[:, 6:] = quaternion
        d_loss = criterion(output1, labels[:, :6])
        # (|q|^2 - 1)^2 pushes the predicted quaternion towards unit norm
        penalty = (output2.norm(dim=1) ** 2 - 1) ** 2
        q_loss = qloss(output2, labels[:, 6:]) + 10 * penalty.sum()
        loss = 10 * d_loss + q_loss
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    epoch_loss = running_loss / len(train_loader)
    lossRecord.append(epoch_loss)
    # print training metrics
    print('Epoch: %d | Loss: %.5f' % (epoch + 1, epoch_loss))

    # Every 10 epochs, evaluate on the whole test set. no_grad() avoids
    # building (and holding memory for) an autograd graph during evaluation.
    if (epoch + 1) % 10 == 0:
        net.eval()
        test_loss = 0.0
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                inputs = inputs.transpose(2, 1)
                output1, output2 = net(inputs)
                d_loss = criterion(output1, labels[:, :6])
                q_loss = qloss(output2, labels[:, 6:])
                test_loss += (10 * d_loss + q_loss).item()
        # mean per batch; unlike the training loss it has no norm penalty
        print('%s Loss: %.5f' % (blue('Test'), test_loss / len(test_loader)))
    scheduler.step()

fig, ax = plt.subplots()
ax.plot(range(1, num_epochs + 1), lossRecord)
ax.set(title='Training loss', xlabel='Epoch', ylabel='Mean loss per batch')
plt.show()

# Results

In [None]:
# Move the model to the CPU (in place) and save its weights.
net.cpu()
torch.save(net.state_dict(), 'net_5k.pth')

In [None]:
# Ground-truth test OBBs keyed by point-cloud filename. The key order matches
# test_set.filenames, since both are read from the same JSON file.
with open(test_labels, 'r') as label_file:
    obb_dict = json.load(label_file)

filenames = [*obb_dict]

In [None]:
# Run the trained model on every test cloud (batch size 1, on the CPU) and
# collect the predicted OBBs keyed by filename.
# The local names eval_criterion / eval_loader avoid overwriting the
# `criterion` and `test_loader` globals that the training cell uses.
net.to(torch.device('cpu'))
eval_criterion = nn.MSELoss()
eval_loader = DataLoader(test_set, batch_size=1)

net.eval()
test_res = {}
test_losses = []

with torch.no_grad():
    # eval_loader is not shuffled, so it yields items in test_set.filenames
    # order, which matches `filenames` (both come from the same JSON keys).
    for name, (inputs, labels) in zip(filenames, eval_loader):
        inputs = inputs.transpose(2, 1)
        output1, output2 = net(inputs)

        d_loss = eval_criterion(output1, labels[:, :6])
        q_loss = qloss(output2, labels[:, 6:])
        test_losses.append((d_loss + q_loss).item())

        # save output into dictionary; the quaternion is converted to a 3x3 matrix
        test_res[name] = {
            'center': output1[:, :3].flatten().tolist(),
            'extent': output1[:, 3:6].flatten().tolist(),
            'R': quaternion_to_matrix(output2).view(3, 3).tolist(),
        }

print('Mean test loss: %.5f' % (sum(test_losses) / max(len(test_losses), 1)))

In [None]:
# Persist the predicted OBBs so the visualisation section can reload them.
with open('res_shapenet.json', 'w') as out_file:
    out_file.write(json.dumps(test_res))

## Visualization

In [60]:
# Reload the saved predictions, so this section runs without re-running inference.
with open('res_shapenet.json', 'r') as in_file:
    test_res = json.loads(in_file.read())

In [61]:
def _obb_from_params(params, extent_pad=0.0):
    '''
    Build an open3d OrientedBoundingBox from a dict with 'center', 'extent'
    and 'R' entries (the format of the label and result JSON files).

    extent_pad is added to every extent component.
    Returns (box, rotation), where rotation is the 3x3 float32 numpy array.
    '''
    center = np.array(params['center'], dtype=np.float32).reshape(3,1)
    extent = np.array(params['extent'], dtype=np.float32).reshape(3,1)
    rotation = np.array(params['R'], dtype=np.float32).reshape(3,3)
    box = o3d.geometry.OrientedBoundingBox(center=center, extent=extent + extent_pad, R=rotation)
    return box, rotation


def evaluate(filename, data_root, disp=False, gt_obbs=None, pred_obbs=None, extent_pad=0.05):
    '''
    Compare the ground-truth and predicted OBB for one point cloud.

    filename   : point-cloud filename (key into the OBB dictionaries)
    data_root  : directory containing the point-cloud files
    disp       : if truthy, open an open3d window showing the cloud, the
                 predicted box (blue) and the ground-truth box (red)
    gt_obbs    : dict of ground-truth OBBs; defaults to the global `obb_dict`
    pred_obbs  : dict of predicted OBBs; defaults to the global `test_res`
    extent_pad : added to every predicted extent before building the drawn box

    Prints the ground-truth rotation matrix, then the predicted one.
    '''
    gt_obbs = obb_dict if gt_obbs is None else gt_obbs
    pred_obbs = test_res if pred_obbs is None else pred_obbs

    # true obb (red)
    obb, obb_r = _obb_from_params(gt_obbs[filename])
    obb_line_set = o3d.geometry.LineSet.create_from_oriented_bounding_box(obb)
    obb_line_set.paint_uniform_color([1,0,0])

    # predicted obb (blue)
    pred_obb, pred_obb_r = _obb_from_params(pred_obbs[filename], extent_pad)
    pred_obb_line_set = o3d.geometry.LineSet.create_from_oriented_bounding_box(pred_obb)
    pred_obb_line_set.paint_uniform_color([0,0,1])

    if disp:
        # the point cloud is only needed for drawing
        pcd = o3d.io.read_point_cloud(os.path.join(data_root, filename))
        o3d.visualization.draw_geometries([pcd, pred_obb_line_set, obb_line_set])

    print(obb_r)
    print(pred_obb_r)

In [62]:
# filenames with a saved prediction, in insertion order
test_filenames = [*test_res]

In [None]:
# Inspect a single test example: print its filename, then draw both boxes.
idx = 41
filename = test_filenames[idx]
print(filename)
evaluate(filename, data_root=rootdir, disp=True)