# PointNet Part Segmentation

Feel free to check out my WandB for some charts and pointcloud visualizations (though there are pointcluod vis. below as well)

https://wandb.ai/arth-shukla/Pointnet%20Chair%20Part%20Segmentation?workspace=user-arth-shukla

## Setup

**NOTE IF YOU"RE USING COLAB**: Colab's native installation for torchvision has conflicting PIL version deps, so after running the install for open3d, please restart your runtime (without deleting it)

This shouldn't be an issue for installing everything locally

In [None]:
# please make sure data is "data.zip"
# alternatively, you can drag and drop directly

# uncomment if you're on colab
# from google.colab import files
# uploaded = files.upload()

In [1]:
# Uncomment to unzip on Google Colab
# !unzip data.zip -d data

In [2]:
# Uncomment to install Wandb in Google Colab
# !pip install wandb

In [3]:
# Uncomment to install open3d on Colab
# !pip install open3d

**IF USING *COLAB*, COMMENT OUT THE ABOVE CELLS, THEN RESTART RUNTIME HERE**

In [4]:
import wandb
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33marth-shukla[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [5]:
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

from pathlib import Path
import os

In [6]:
if (torch.cuda.is_available()):
    print(torch.cuda.device_count(), torch.cuda.current_device())

1 0


## Visualizing Data

In [7]:
import open3d as o3d

def visualize(pts, labels=None):

    is_labeled = False
    if type(labels) != type(None):
        is_labeled = True

    try:
        pts = pts.numpy()
        labels = labels.numpy()
        print('Tensor converted to numpy')
    except:
        print('Passed numpy')

    # red, green, blue, purple
    colors = np.array([[0,0,0],[255,0,0],[0,255,0],[0,0,255],[100,0,100]])
    if is_labeled: labels = colors[labels]
    
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(pts)
    if is_labeled: pcd.colors = o3d.utility.Vector3dVector(labels)

    o3d.visualization.draw_plotly([pcd])

In [8]:
eg_pts_hash = '1d1b37ce6d72d7855096c0dd2594842a'
eg_pts_path = f'data/train/pts/{eg_pts_hash}.pts'
eg_labels_path = f'data/train/label/{eg_pts_hash}.txt'

ex_pc = np.loadtxt(eg_pts_path, delimiter=' ')
ex_labels = np.loadtxt(eg_labels_path, delimiter=' ', dtype=np.int8)

print(ex_pc.shape, ex_labels.shape)

(2704, 3) (2704,)


In [9]:
# Uncomment to run and see a train datapoint example
# visualize(torch.from_numpy(ex_pc), torch.from_numpy(ex_labels))

## Custom Dataset + Make DataLoaders

In [10]:
class ChairCloudDataset(Dataset):
    def __init__(self, root: str, train: bool=True, transform=torch.from_numpy, target_transform=torch.from_numpy):
        self.train = train

        self.transform = transform
        self.target_transform = target_transform

        root = Path(root)
        pts_dir = None
        labels_dir = None
        if self.train:
            pts_dir = root / Path('pts')
            labels_dir = root / Path('label')
        else:
            pts_dir = Path(root)

        self.pts_paths = []
        self.label_paths = []
        for pts_file in os.listdir(pts_dir):
            self.pts_paths.append(pts_dir / pts_file)

            if self.train:
                self.label_paths.append(labels_dir / (Path(pts_file).stem + '.txt'))

        dataset_type = 'train' if self.train else 'test'
        print(f'Found {len(self.pts_paths)} {dataset_type} datapoints')

    def __getitem__(self, index):
        pts_file = self.pts_paths[index]
        pts = np.loadtxt(pts_file, delimiter=' ')
        pts = self.transform(pts)

        if self.train:
            label_file = self.label_paths[index]
            seg = np.loadtxt(label_file, delimiter=' ', dtype=np.uint8)
            seg = self.target_transform(seg)

            return pts.float(), seg
        
        return pts.float()

    def __len__(self):
        return len(self.pts_paths)

In [11]:
train_data = ChairCloudDataset('data/train', train=True)
test_data = ChairCloudDataset('data/test', train=False)

Found 1000 train datapoints
Found 6 test datapoints


In [12]:
# Uncomment to visualize
# visualize(*train_data[0])

In [13]:
from torch.nn.utils.rnn import pad_sequence

def pad_test(batch):
    return pad_sequence(batch, batch_first=True, padding_value=1)

def pad_train(batch):
    pts, seg = zip(*batch)
    
    pts_padded = pad_sequence(pts, batch_first=True, padding_value=1)
    seg_padded = pad_sequence(seg, batch_first=True, padding_value=1)

    return pts_padded, seg_padded

In [14]:
# train_dl = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=pad_train, num_workers=2)
# test_dl = DataLoader(test_data, batch_size=32, shuffle=True, collate_fn=pad_test, num_workers=2)

## Building PointNet Model

First, we must construct the input and feature transform steps. This consistes of a T-Net with $k$-long input, following by a matmul.

We will begin by implementing the T-Nets for general $k$-long input (we will later use once with nx3 input for input transform, then once for nx64 input for feature transform).

<img src='./imgs/annotated-tnet.jpg'>

In [15]:

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

class TNet(nn.Module):
    def __init__(self, k: int):
        super().__init__()

        self.k = k

        # 1st convolution + "mlp"
        self.conv1 = nn.Conv1d(k, 64, 1)
        self.batchnorm1 = nn.BatchNorm1d(64)
        # relu called directly

        # 2nd convolution + "mlp"
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.batchnorm2 = nn.BatchNorm1d(128)
        # relu called directly

        # 3rd convolution + "mlp"
        self.conv3 = nn.Conv1d(128, 1024, 1)
        self.batchnorm3 = nn.BatchNorm1d(1024)
        # relu called directly

        # max pool called directly

        # 1st fully connected layer after max pool
        self.fc1 = nn.Linear(1024, 512)
        self.batchnorm4 = nn.BatchNorm1d(512)
        # relu called directly

        # 2nd fully connected layer after max pool
        self.fc2 = nn.Linear(512, 256)
        self.batchnorm5 = nn.BatchNorm1d(256)
        # relu called directly

        # generate final output tensor for matmul
        self.fc3 = nn.Linear(256, self.k * self.k)

    def forward(self, input):
        batch_size = input.size(0)

        # 1st convolution
        x = F.relu(self.batchnorm1(self.conv1(input)))
        # 2nd convolution
        x = F.relu(self.batchnorm2(self.conv2(x)))
        # 3rd convolution
        x = F.relu(self.batchnorm3(self.conv3(x)))

        # max pool
        x = nn.MaxPool1d(x.size(-1))(x)
        x = nn.Flatten(1)(x)

        # 1st fully connected layer after max pool
        x = F.relu(self.batchnorm4(self.fc1(x)))
        # 2nd fully connected layer after max pool
        x = F.relu(self.batchnorm5(self.fc2(x)))

        # final fc before reshaping to output matrix
        x = self.fc3(x)

        # init matrix to identity for orthogonality
        matrix = torch.eye(self.k, requires_grad=True).flatten().repeat(batch_size, 1)
        if torch.cuda.is_available():
            matrix = matrix.cuda()
        # add to last fc layer
        matrix = x + matrix
        # reshape to batch_size x 3 x 3
        matrix = matrix.view(batch_size, self.k, self.k)
        
        return matrix

Now we can take the TNet and place it in the greater PointNet, completing the transformations as well

In [16]:
class PointNet(nn.Module):
    def __init__(self, num_classes = 5):
        super().__init__()

        self.num_classes = num_classes

        # Input Transform TNet
        self.input_tnet = TNet(3)

        # 1st shared mlp between input and feature transform steps
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.batchnorm1 = nn.BatchNorm1d(64)
        # relu called directly

        # Feature Trasformation TNet
        self.feature_tnet = TNet(64)

        # 2nd shared mlp, 1st convolution
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.batchnorm2 = nn.BatchNorm1d(128)
        # relu called directly

        # 2nd shared mlp, 2nd convolution
        self.conv3 = nn.Conv1d(128, 1024, 1)
        self.batchnorm3 = nn.BatchNorm1d(1024)
        # relu called directly


        # segmentation network, shared mlp, convolution 1
        self.conv4 = nn.Conv1d(1088, 512, 1)
        self.batchnorm4 = nn.BatchNorm1d(512)
        # relu called directly

        # segmentation network, shared mlp, convolution 2
        self.conv5 = nn.Conv1d(512, 256, 1)
        self.batchnorm5 = nn.BatchNorm1d(256)
        # relu called directly

        # segmentation network, shared mlp, convolution 3
        self.conv6 = nn.Conv1d(256, 128, 1)
        self.batchnorm6 = nn.BatchNorm1d(128)
        # relu called directly

        # segmentation network, shared mlp, convolution 4
        self.conv7 = nn.Conv1d(128, self.num_classes, 1)
        # no relu or batchnorm bc this will be the output (after logsoftmax)

        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, input):
        num_pts = input.size(-1)

        # input transformation
        mat3x3 = self.input_tnet(input)

        x = input.transpose(2, 1)       # align dims
        x = torch.bmm(x, mat3x3)
        x = x.transpose(2, 1)           # put channels back in correct spot

        # 1st shared mlp between input and feature transform steps
        x = F.relu(self.batchnorm1(self.conv1(x)))

        # feature transformation
        mat64x64 = self.feature_tnet(x)
        x = x.transpose(2, 1)
        x = torch.bmm(x, mat64x64)
        x = x.transpose(2, 1)

        # save for segmentation network later
        feature_matrix = x

        # 2nd shared mlp convolutions
        x = F.relu(self.batchnorm2(self.conv2(x)))
        x = F.relu(self.batchnorm3(self.conv3(x)))

        # Max Pool for symmmetric func / perm invariance
        x = nn.MaxPool1d(x.size(-1))(x)
        # repeat to later join with feature matrix
        x = x.repeat(1, 1, num_pts)
        # begin segmentation network section
        seg_x = torch.cat((feature_matrix, x), 1)

        # segmentation network, shared mlp convolutions
        seg_x = F.relu(self.batchnorm4(self.conv4(seg_x)))
        seg_x = F.relu(self.batchnorm5(self.conv5(seg_x)))
        seg_x = F.relu(self.batchnorm6(self.conv6(seg_x)))

        # this is our pointwise class pred
        seg_x = self.conv7(seg_x)
        seg_x = seg_x.transpose(2,1)
        seg_x = self.logsoftmax(seg_x)

        return seg_x, mat64x64
        

## Training Functions

For wandb visualization

In [17]:
def get_pred(pnet, pts, num_classes=5):
    pnet = pnet.eval()

    # prep data
    pts = pts.view(1, pts.size(0), pts.size(1))
    pts = pts.transpose(2, 1)

    pnet = pnet.eval()

    # get predictions
    pred, _ = pnet(pts)
    pred = pred.view(-1, num_classes)

    return pred.data.max(1)[1]

def get_pred_cloud(pnet, targ, num_classes=5):
    pts, _ = targ

    if torch.cuda.is_available():
        pts = pts.cuda()

    pred = get_pred(pnet, pts).view(-1, 1)

    return torch.hstack((pts, pred))

For train loop

In [18]:
def regularize(m64):
    # paper gives regularization of form ||I - AA^T||^2, where A is the feature transformed matrix (64x64)
    I = torch.flatten(torch.eye(m64.size(1)), start_dim=1)
    if torch.cuda.is_available():
        I = I.cuda()
    loss = torch.mean(torch.norm(I - torch.bmm(m64, m64.transpose(2,1)), dim=(1,2)))
    return loss

def train_step(pnet, train_dl, optimizer, epoch = 0, reg_strength = 0.001, num_classes = 5, print_batch_metircs=False):

    # train in batches
    batch = 0
    train_accuracy = 0
    train_loss = 0
    for data in iter(train_dl):

        pnet = pnet.train()

        batch += 1

        # prep data
        pts, label = data
        pts = pts.transpose(2, 1)
        if torch.cuda.is_available():
            pts, label = pts.cuda(), label.cuda()

        # zero out grads
        optimizer.zero_grad()

        # get predictions
        pred, m64 = pnet(pts)
        pred = pred.view(-1, num_classes)
        label = torch.flatten(label)

        # calc loss
        loss = F.nll_loss(pred, label) + reg_strength * regularize(m64)
        loss.backward()
        train_loss += loss

        # descent step
        optimizer.step()

        # calc accuracy of pred
        pred_segmentation = pred.data.max(1)[1]
        correct = pred_segmentation.eq(label).cpu().sum()
        accuracy = correct.item()/label.size(0)
        train_accuracy += accuracy

        if print_batch_metircs:
            print(f'epoch: {epoch}\tbatch: {batch}/{len(train_dl)}\ttrain_acc: {accuracy}\ttrain_loss: {loss.item()}')

    train_accuracy = train_accuracy / len(train_dl)
    train_loss = train_loss / len(train_dl)
    print(f'epoch: {epoch}\ttrain_acc: {train_accuracy}\ttrain_loss: {train_loss}')
    
    return train_accuracy, train_loss


def val_step(pnet, val_dl, reg_strength = 0.001, num_classes = 5):
    with torch.no_grad():
        val_accuracy = 0
        val_loss = 0
        for val_data in iter(val_dl):

            pnet = pnet.eval()

            # prep data
            pts, label = val_data
            pts = pts.transpose(2, 1)
            if torch.cuda.is_available():
                pts, label = pts.cuda(), label.cuda()

            # get predictions
            pred, _ = pnet(pts)
            pred = pred.view(-1, num_classes)
            label = torch.flatten(label)

            # calc loss
            loss = F.nll_loss(pred, label) #+ reg_strength * regularize(m64)
            val_loss += loss

            # calc accuracy of pred
            pred_segmentation = pred.data.max(1)[1]
            correct = pred_segmentation.eq(label).cpu().sum()
            accuracy = correct.item()/label.size(0)
            val_accuracy += accuracy

            del pts
            del pred
            del label

        val_accuracy = val_accuracy / len(val_dl)
        val_loss = val_loss / len(val_dl)
        print(f'val_acc: {val_accuracy}\tval_loss: {val_loss}')

        return val_accuracy / len(val_dl), val_loss / len(val_dl)

def train(train_dl, val_dl, vis_targ, epochs=10, acc_req=0.9, lr=0.001, reg_strength=0.001, num_classes=5, run_val_every=10, wandb_logs=False, wandb_vis=False, print_batch_metircs=False):

    # get device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # init pointnet
    pnet = PointNet(num_classes=num_classes)
    pnet.to(device)
    if torch.cuda.is_available():
        pnet.cuda()

    # using adam for faster convergence
    optimizer = torch.optim.Adam(pnet.parameters(), lr=lr)

    # run train loop
    for epoch in range(epochs):

        # train in batches
        train_accuracy, train_loss = train_step(pnet.train(), train_dl, optimizer, epoch=epoch, reg_strength=reg_strength, num_classes=num_classes, print_batch_metircs=print_batch_metircs)

        wandb_log = dict()

        # validation + wandb val visualizing
        if epoch % run_val_every == 0 and run_val_every > 0:
            val_accuracy, val_loss = val_step(pnet, val_dl, reg_strength=reg_strength, num_classes=num_classes)

            if wandb_logs:
                # log val metrics to wandb
                wandb_log['Val Accuracy'] = val_accuracy
                wandb_log['Val Loss'] = val_loss

                if wandb_vis:
                    # generate visualization example for wandb
                    vis_pc = get_pred_cloud(pnet, vis_targ).cpu().numpy()
                    wandb_log['generated_samples'] = wandb.Object3D(vis_pc)


        # logging to wandb
        if wandb_logs:
            wandb_log['Train Accuracy'] = train_accuracy
            wandb_log['Train Loss'] = train_loss
            wandb.log(wandb_log)

        # break if req accuracy reached
        if (train_accuracy > acc_req):
            break

    return pnet

## Train Time >:)

In [19]:
# I've included some options for code debugging
RUN_VAL_EVERY = 1
WANDB_LOGS = False
WANDB_VIS = False
PRINT_BATCH_METRICS = False

In [20]:
# per https://github.com/charlesq34/pointnet/issues/26, batch size 4 works for 4gb vram
# my pc has 8gb and can handle batch size 16 with only 5gb usage
# colab t4s have 15-16gb vram, so if running on colab, it'll likely handle 32, probably more
BATCH_SIZE = 16

EPOCHS = 15
ACC_REQ = 0.9
LR = 0.001
REG_STRENGTH = 0.001

In [21]:
if WANDB_LOGS:
  run = wandb.init(project='Pointnet Chair Part Segmentation')

In [22]:
from torch.utils.data import random_split
train_subset, val_subset = random_split(train_data, [0.7, 0.3])

In [23]:
train_dl = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=pad_train)
val_dl = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=pad_train)
test_dl = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=pad_test)

In [24]:
trained_pnet = train(train_dl, val_dl, val_subset[0], epochs=EPOCHS, acc_req=ACC_REQ, lr=LR, reg_strength=REG_STRENGTH, run_val_every=RUN_VAL_EVERY, wandb_logs=WANDB_LOGS, wandb_vis=WANDB_VIS, print_batch_metircs=PRINT_BATCH_METRICS)

epoch: 0	train_acc: 0.7937704325396837	train_loss: 7.366321563720703
val_acc: 0.4431013742694041	val_loss: 7.8893561363220215
epoch: 1	train_acc: 0.8626900357440118	train_loss: 7.129920959472656
val_acc: 0.821409492997895	val_loss: 7.199841022491455
epoch: 2	train_acc: 0.87362236581174	train_loss: 7.089442253112793
val_acc: 0.8081498716836899	val_loss: 7.176362037658691
epoch: 3	train_acc: 0.8802996905463357	train_loss: 7.066039085388184
val_acc: 0.8124326198123211	val_loss: 7.186957836151123
epoch: 4	train_acc: 0.880223314650327	train_loss: 7.064637660980225
val_acc: 0.8361413893426726	val_loss: 7.117361068725586
epoch: 5	train_acc: 0.8806899299191098	train_loss: 7.058892250061035
val_acc: 0.870220960779453	val_loss: 7.071412086486816
epoch: 6	train_acc: 0.885535906450983	train_loss: 7.053778648376465
val_acc: 0.8614667528540417	val_loss: 7.068543910980225
epoch: 7	train_acc: 0.8848242355677877	train_loss: 7.045814037322998
val_acc: 0.8491467345056214	val_loss: 7.0641021728515625
epoc

## Testing

In [32]:
def visualize_test(i, pnet, test_data):
    pnet = pnet.eval()

    pts = test_data[i]

    if torch.cuda.is_available():
        pts = pts.cuda()

    pred = get_pred(pnet, pts).view(-1, 1)

    pts_numpy = pts.cpu().numpy()
    pred_numpy = pred.cpu().numpy()

    # receiving odd errors regarding 
    np.savetxt(f'pts_{i}.txt', pts_numpy, delimiter=' ')
    np.savetxt(f'pred_{i}.txt', pred_numpy, delimiter=' ')

    pts = np.loadtxt(f'pts_{i}.txt', delimiter=' ')
    pred = np.loadtxt(f'pred_{i}.txt', delimiter=' ', dtype=np.int8)

    visualize(pts, pred)

In [33]:
visualize_test(0, trained_pnet, test_data)

Passed numpy


In [34]:
visualize_test(1, trained_pnet, test_data)

Passed numpy


In [35]:
visualize_test(2, trained_pnet, test_data)

Passed numpy


In [36]:
visualize_test(3, trained_pnet, test_data)

Passed numpy


In [37]:
visualize_test(4, trained_pnet, test_data)

Passed numpy


In [38]:
visualize_test(5, trained_pnet, test_data)

Passed numpy
