# PointNet Part Segmentation

## Imports

In [114]:
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# import open3d as o3d

from pathlib import Path
import os

In [115]:
if (torch.cuda.is_available()):
    print(torch.cuda.device_count(), torch.cuda.current_device())

## Visualizing Data

In [116]:
# def visualize(pts, labels=None):

#     is_labeled = False
#     if type(labels) != type(None):
#         is_labeled = True

#     try:
#         pts = pts.numpy()
#         labels = labels.numpy()
#         print('Tensor converted to numpy')
#     except:
#         print('Passed numpy')

#     # red, green, blue, purple
#     colors = np.array([[255,0,0],[0,255,0],[0,0,255],[100,0,100]])
#     if is_labeled: labels = colors[labels-1]

#     pcd = o3d.geometry.PointCloud()
#     pcd.points = o3d.utility.Vector3dVector(pts)
#     if is_labeled: pcd.colors = o3d.utility.Vector3dVector(labels)

#     o3d.visualization.draw_geometries([pcd])

In [117]:
eg_pts_hash = '1d1b37ce6d72d7855096c0dd2594842a'
eg_pts_path = f'data/train/pts/{eg_pts_hash}.pts'
eg_labels_path = f'data/train/label/{eg_pts_hash}.txt'

ex_pc = np.loadtxt(eg_pts_path, delimiter=' ')
ex_labels = np.loadtxt(eg_labels_path, delimiter=' ', dtype=np.int8)

print(ex_pc.shape, ex_labels.shape)

(2704, 3) (2704,)


In [118]:
# Uncomment to run, requires system GUI
# visualize(torch.from_numpy(ex_pc), torch.from_numpy(ex_labels))

## Custom Dataset + Make DataLoaders

In [119]:
class ChairCloudDataset(Dataset):
    def __init__(self, root: str, train: bool=True, transform=torch.from_numpy, target_transform=torch.from_numpy):
        self.train = train

        self.transform = transform
        self.target_transform = target_transform

        root = Path(root)
        pts_dir = None
        labels_dir = None
        if self.train:
            pts_dir = root / Path('pts')
            labels_dir = root / Path('label')
        else:
            pts_dir = Path(root)

        self.pts_paths = []
        self.label_paths = []
        for pts_file in os.listdir(pts_dir):
            self.pts_paths.append(pts_dir / pts_file)

            if self.train:
                self.label_paths.append(labels_dir / (Path(pts_file).stem + '.txt'))

        dataset_type = 'train' if self.train else 'test'
        print(f'Found {len(self.pts_paths)} {dataset_type} datapoints')

    def __getitem__(self, index):
        pts_file = self.pts_paths[index]
        pts = np.loadtxt(pts_file, delimiter=' ')
        pts = self.transform(pts)

        if self.train:
            label_file = self.label_paths[index]
            seg = np.loadtxt(label_file, delimiter=' ', dtype=np.uint8)
            seg = self.target_transform(seg)

            return pts.float(), seg
        
        return pts.float()

    def __len__(self):
        return len(self.pts_paths)
            



In [120]:
train_data = ChairCloudDataset('data/train', train=True)
test_data = ChairCloudDataset('data/test', train=False)

Found 1000 train datapoints
Found 6 test datapoints


In [121]:
# Uncomment to visualize
# visualize(*train_data[0])

In [122]:
from torch.nn.utils.rnn import pad_sequence

def pad_test(batch):
    return pad_sequence(batch, batch_first=True, padding_value=0)

def pad_train(batch):
    pts, seg = zip(*batch)
    
    pts_padded = pad_sequence(pts, batch_first=True, padding_value=0.0)
    seg_padded = pad_sequence(seg, batch_first=True, padding_value=0)

    return pts_padded, seg_padded

In [123]:
# train_dl = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=pad_train, num_workers=2)
# test_dl = DataLoader(test_data, batch_size=32, shuffle=True, collate_fn=pad_test, num_workers=2)

## Building PointNet Model

First, we must construct the input and feature transform steps. This consistes of a T-Net with $k$-long input, following by a matmul.

We will begin by implementing the T-Nets for general $k$-long input (we will later use once with nx3 input for input transform, then once for nx64 input for feature transform).

<img src='./imgs/annotated-tnet.jpg'>

In [124]:

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

class TNet(nn.Module):
    def __init__(self, k: int):
        super().__init__()

        self.k = k

        # 1st convolution + "mlp"
        self.conv1 = nn.Conv1d(k, 64, 1)
        self.batchnorm1 = nn.BatchNorm1d(64)
        # relu called directly

        # 2nd convolution + "mlp"
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.batchnorm2 = nn.BatchNorm1d(128)
        # relu called directly

        # 3rd convolution + "mlp"
        self.conv3 = nn.Conv1d(128, 1024, 1)
        self.batchnorm3 = nn.BatchNorm1d(1024)
        # relu called directly

        # max pool called directly

        # 1st fully connected layer after max pool
        self.fc1 = nn.Linear(1024, 512)
        self.batchnorm4 = nn.BatchNorm1d(512)
        # relu called directly

        # 2nd fully connected layer after max pool
        self.fc2 = nn.Linear(512, 256)
        self.batchnorm5 = nn.BatchNorm1d(256)
        # relu called directly

        # generate final output tensor for matmul
        self.fc3 = nn.Linear(256, self.k * self.k)

    def forward(self, input):
        batch_size = input.size(0)

        # 1st convolution
        x = F.relu(self.batchnorm1(self.conv1(input)))
        # 2nd convolution
        x = F.relu(self.batchnorm2(self.conv2(x)))
        # 3rd convolution
        x = F.relu(self.batchnorm3(self.conv3(x)))

        # max pool
        x = nn.MaxPool1d(x.size(-1))(x)
        x = nn.Flatten(1)(x)

        # 1st fully connected layer after max pool
        x = F.relu(self.batchnorm4(self.fc1(x)))
        # 2nd fully connected layer after max pool
        x = F.relu(self.batchnorm5(self.fc2(x)))

        # final fc before reshaping to output matrix
        x = self.fc3(x)

        # init matrix to identity for orthogonality
        matrix = torch.eye(self.k, requires_grad=True).flatten().repeat(batch_size, 1)
        if matrix.is_cuda:
            matrix = matrix.cuda()
        # add to last fc layer
        matrix = x + matrix
        # reshape to batch_size x 3 x 3
        matrix = matrix.view(batch_size, self.k, self.k)
        
        return matrix

Now we can take the TNet and place it in the greater PointNet, completing the transformations as well

In [125]:
class PointNet(nn.Module):
    def __init__(self, num_classes = 5):
        super().__init__()

        self.num_classes = num_classes

        # Input Transform TNet
        self.input_tnet = TNet(3)

        # 1st shared mlp between input and feature transform steps
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.batchnorm1 = nn.BatchNorm1d(64)
        # relu called directly

        # Feature Trasformation TNet
        self.feature_tnet = TNet(64)

        # 2nd shared mlp, 1st convolution
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.batchnorm2 = nn.BatchNorm1d(128)
        # relu called directly

        # 2nd shared mlp, 2nd convolution
        self.conv3 = nn.Conv1d(128, 1024, 1)
        self.batchnorm3 = nn.BatchNorm1d(1024)
        # relu called directly


        # segmentation network, shared mlp, convolution 1
        self.conv4 = nn.Conv1d(1088, 512, 1)
        self.batchnorm4 = nn.BatchNorm1d(512)
        # relu called directly

        # segmentation network, shared mlp, convolution 2
        self.conv5 = nn.Conv1d(512, 256, 1)
        self.batchnorm5 = nn.BatchNorm1d(256)
        # relu called directly

        # segmentation network, shared mlp, convolution 3
        self.conv6 = nn.Conv1d(256, 128, 1)
        self.batchnorm6 = nn.BatchNorm1d(128)
        # relu called directly

        # segmentation network, shared mlp, convolution 4
        self.conv7 = nn.Conv1d(128, self.num_classes, 1)
        # no relu or batchnorm bc this will be the output (after logsoftmax)

        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, input):
        num_pts = input.size(-1)

        # input transformation
        mat3x3 = self.input_tnet(input)

        x = input.transpose(2, 1)       # align dims
        x = torch.bmm(x, mat3x3)
        x = x.transpose(2, 1)           # put channels back in correct spot

        # 1st shared mlp between input and feature transform steps
        x = F.relu(self.batchnorm1(self.conv1(x)))

        # feature transformation
        mat64x64 = self.feature_tnet(x)
        x = x.transpose(2, 1)
        x = torch.bmm(x, mat64x64)
        x = x.transpose(2, 1)

        # save for segmentation network later
        feature_matrix = x

        # 2nd shared mlp convolutions
        x = F.relu(self.batchnorm2(self.conv2(x)))
        x = F.relu(self.batchnorm3(self.conv3(x)))

        # Max Pool for symmmetric func / perm invariance
        x = nn.MaxPool1d(x.size(-1))(x)
        # repeat to later join with feature matrix
        x = x.repeat(1, 1, num_pts)
        # begin segmentation network section
        seg_x = torch.cat((feature_matrix, x), 1)

        # segmentation network, shared mlp convolutions
        seg_x = F.relu(self.batchnorm4(self.conv4(seg_x)))
        seg_x = F.relu(self.batchnorm5(self.conv5(seg_x)))
        seg_x = F.relu(self.batchnorm6(self.conv6(seg_x)))

        # this is our pointwise class pred
        seg_x = self.conv7(seg_x)
        seg_x = seg_x.transpose(2,1)
        seg_x = self.logsoftmax(seg_x)

        return seg_x, mat64x64
        

## Training Functions

In [126]:
def train_step(pnet, train_dl, optimizer, epoch = 0, reg_strength = 0.001, num_classes=5):

    # paper descrived regularization of form ||I - AA^T||^2, where A is the feature transformed matrix (64x64)
    def regularize(m64):
        I2 = torch.flatten(torch.eye(10), start_dim=1)
        if m64.is_cuda:
            I = I.cuda()
        loss = torch.mean(torch.norm(I - torch.bmm(m64, m64.transpose(2,1)), dim=(1,2)) ** 2)
        return loss

    # train in batches
    batch = 0
    for data in iter(train_dl):
        batch += 1

        # prep data
        pts, label = data
        pts = pts.transpose(2, 1)
        if torch.cuda.is_available():
            points, label = points.cuda(), label.cuda()

        # zero out grads and get model ready to train
        optimizer.zero_grad()
        pnet = pnet.train()

        # get predictions
        pred, m64 = pnet(pts)
        pred = pred.view(-1, num_classes)
        label = torch.flatten(label)

        # calc loss
        loss = F.nll_loss(pred, label) + reg_strength * regularize(m64)
        loss.backward()

        # descent step
        optimizer.step()

        # calc accuracy of pred
        pred_choice = pred.data.max(1)[1]
        correct = pred_choice.eq(label).cpu().sum()
        accuracy = correct.item()/label.size(0)

        print(f'epoch: {epoch}\tbatch: {batch}/{len(train_dl)}\taccuracy: {accuracy}\tloss: {loss.item()}')

def train(train_dl, max_epochs=100, acc_req=0.9, lr=0.001, reg_strength=0.001, num_classes=5):

    # get device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # init pointnet
    pnet = PointNet(num_classes=num_classes)
    pnet.to(device)
    if torch.cuda.is_available():
        pnet.cuda()

    # using adam for faster convergence
    optimizer = torch.optim.Adam(pnet.parameters(), lr=lr)

    # run train loop
    for epoch in range(max_epochs):

        # train in batches
        epoch_accuracy = train_step(pnet, train_dl, optimizer, epoch=epoch, reg_strength=reg_strength, num_classes=num_classes)

        # break if req accuracy reached already
        if (epoch_accuracy > acc_req):
            break

In [127]:
BATCH_SIZE = 32     # 64 blew up my laptop :(
MAX_EPOCHS = 100    # absolute max, if we reach 100 i've messed up greatly
ACC_REQ = 0.9       # should be good enough
LR = 0.001
REG_STRENGTH = 0.001

In [128]:
train_dl = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=pad_train)
test_dl = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=pad_test)

In [129]:
train(train_dl, max_epochs=MAX_EPOCHS, acc_req=ACC_REQ, lr=LR, reg_strength=REG_STRENGTH)

epoch: 0	batch: 1/32	accuracy: 0.21553050281293953	loss: 8.10023307800293
epoch: 0	batch: 2/32	accuracy: 0.413625845496618	loss: 8.013431549072266
epoch: 0	batch: 3/32	accuracy: 0.5695557414582599	loss: 7.887115955352783
epoch: 0	batch: 4/32	accuracy: 0.6088152173913044	loss: 7.77882194519043
epoch: 0	batch: 5/32	accuracy: 0.6258943965517242	loss: 7.712094783782959
epoch: 0	batch: 6/32	accuracy: 0.6615980063514467	loss: 7.64022970199585
epoch: 0	batch: 7/32	accuracy: 0.6980481072555205	loss: 7.603155612945557
epoch: 0	batch: 8/32	accuracy: 0.7146908555594043	loss: 7.570329189300537
epoch: 0	batch: 9/32	accuracy: 0.7311794781382228	loss: 7.511545181274414
epoch: 0	batch: 10/32	accuracy: 0.6751543209876543	loss: 7.523986339569092
epoch: 0	batch: 11/32	accuracy: 0.705314595834804	loss: 7.464018821716309
epoch: 0	batch: 12/32	accuracy: 0.7341193990042674	loss: 7.405744552612305
epoch: 0	batch: 13/32	accuracy: 0.744890143557423	loss: 7.3918776512146
epoch: 0	batch: 14/32	accuracy: 0.7237948

KeyboardInterrupt: 