In [9]:
import torch

# Fixed seed so weight initialisation and DataLoader shuffling are reproducible
# across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# PointNet Model

Each datapoint consists of a set $\{x_i\}_{i=1}^{k}$, where $x_i = [\phi, \theta, \text{depth}, v]$. However, the number of points $k$ may differ between datapoints, so each batch is zero-padded to the length of its longest set.

In [10]:
import torch.nn as nn


class MLP(nn.Module):
    """Shared point-wise MLP: ``Linear -> BatchNorm1d -> ReLU`` per layer.

    Args:
        sizes: feature widths ``[in, h1, ..., out]``; one layer is built for
            every consecutive pair, i.e. ``len(sizes) - 1`` layers.

    ``forward`` takes ``x`` with the feature axis last and at least three
    dimensions (it swaps axes 1 and 2 around each BatchNorm). For an input
    of shape ``(batch, num_points, sizes[0])`` it returns
    ``(batch, num_points, sizes[-1])``.
    """

    def __init__(self, sizes):
        super(MLP, self).__init__()
        self.sizes = sizes

        self.layers = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        self.activations = nn.ReLU()

        # One Linear and one BatchNorm (over the Linear's output width) per
        # consecutive (in, out) pair of widths.
        for in_features, out_features in zip(sizes[:-1], sizes[1:]):
            self.layers.append(nn.Linear(in_features, out_features))
            self.batch_norms.append(nn.BatchNorm1d(out_features))

    def forward(self, x):
        for linear, norm in zip(self.layers, self.batch_norms):
            x = linear(x)
            # BatchNorm1d normalises dimension 1, so move the feature axis
            # there for the norm and then back to the last position.
            x = norm(x.transpose(1, 2)).transpose(1, 2)
            x = self.activations(x)
        return x




class PointNet(nn.Module):
    """Simplified PointNet classifier for sets of points.

    Args:
        input_size: number of features per point (width of the input's last axis).
        num_classes: number of output logits per sample.
        dropout: dropout probability used inside the classification head.

    Unlike the original PointNet, the input and feature "transforms" here are
    a single learned matrix each (initialised to the identity) that is shared
    by all samples, rather than per-sample T-Nets.
    """

    def __init__(self, input_size=3, num_classes=10, dropout=0.3):
        super(PointNet, self).__init__()

        self.input_size = input_size
        self.num_classes = num_classes
        self.dropout = dropout

        # Learned, input-independent alignment matrices (identity at init).
        self.input_transform = torch.nn.Parameter(torch.eye(self.input_size))
        self.feature_transform = torch.nn.Parameter(torch.eye(64))

        # Shared per-point MLPs, applied to every point independently.
        self.mlp1 = MLP([self.input_size, 64, 64])

        self.mlp2 = MLP([64, 64, 128, 1024])

        # Classification head applied to the pooled 1024-d global feature.
        self.mlp_out = nn.Sequential(
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.Dropout(self.dropout),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, self.num_classes),
        )

    def forward(self, x):
        """Map a batch of point sets to class logits.

        Args:
            x: tensor of shape ``(batch, num_points, input_size)``.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        x = x @ self.input_transform
        x = self.mlp1(x)
        x = x @ self.feature_transform
        x = self.mlp2(x)
        # Max-pool over the point axis gives an order-invariant global feature.
        # NOTE(review): zero-padded rows added by collate_fn also take part in
        # this max (and in the BatchNorm statistics). A padding mask would
        # exclude them.
        x = torch.max(x, dim=-2)[0]
        # Apply the classification head. Without it the model would emit the
        # 1024-d global feature instead of num_classes logits.
        return self.mlp_out(x)


# Smoke test: a random batch must produce exactly one logit vector per sample.
model = PointNet().to(device)
x = torch.randn(32, 251, 3).to(device)
with torch.no_grad():
    assert model(x).shape == (32, model.num_classes)
del model, x

# Load Data

In [11]:
from load_data import load_data
import numpy as np

# NOTE(review): `load_data` is a project helper that is not shown here. Judging
# from how its outputs are used below, it presumably returns
# (per-sample point arrays, labels), with unused trailing rows marked by -1.
# Confirm against load_data.py.
train_data, train_labels = load_data()
test_data, _ = load_data('test')  # second return value (presumably labels) is ignored

def remove_nans(data):
    """Strip the trailing ``-1`` padding from each sample.

    Each sample is truncated just before the first row that contains a ``-1``.
    Samples containing no ``-1`` at all are returned unchanged; calling
    ``.min()`` on the empty index array would otherwise raise ``ValueError``.

    Args:
        data: iterable of arrays, assumed 2-D ``(num_rows, num_features)``.

    Returns:
        list with one (possibly shorter) array per input sample.
    """
    new_data = []
    for d in data:
        # Row indices of every entry equal to -1 (np.where returns (rows, cols)).
        pad_rows = np.where(d == -1)[0]
        new_data.append(d[:pad_rows.min()] if pad_rows.size else d)
    return new_data

train_data = remove_nans(train_data)
test_data = remove_nans(test_data)

# Per-feature normalisation statistics come from the training points only,
# so nothing from the test set leaks into preprocessing. Concatenate once.
all_train_points = np.concatenate(train_data)
means = all_train_points.mean(axis=0)
stds = all_train_points.std(axis=0)
# A constant feature has std == 0; divide it by 1 instead of producing NaN/inf.
stds = np.where(stds == 0, 1.0, stds)
del all_train_points

train_data = [(d - means) / stds for d in train_data]
test_data = [(d - means) / stds for d in test_data]

In [12]:
from torch.utils.data import Dataset, DataLoader

class PointCloudDataset(Dataset):
    """Map-style dataset pairing each point cloud with its optional label.

    Args:
        data: indexable collection of per-sample point arrays.
        labels: indexable collection of labels aligned with ``data``, or
            ``None`` for unlabelled data, in which case every item's label
            is ``None``.
    """

    def __init__(self, data, labels=None):
        self.data = data
        self.labels = labels

    def __getitem__(self, index):
        if self.labels is None:
            return self.data[index], None
        # Look up the label before the sample, as the original did.
        label = self.labels[index]
        return self.data[index], label

    def __len__(self):
        return len(self.data)
    
# Labelled training set. The test set is built without labels, so its items
# are (points, None).
train_dataset = PointCloudDataset(train_data, train_labels)
test_dataset = PointCloudDataset(test_data)

def collate_fn(batch):
    """Zero-pad variable-length point clouds into a single batch tensor.

    Args:
        batch: list of ``(points, label)`` pairs. ``points`` is a 2-D array
            ``(num_points, num_features)``; ``label`` is an integer or ``None``
            (the latter for unlabelled ``PointCloudDataset`` items).

    Returns:
        ``(data, labels)``:

        - ``data`` is a float32 tensor ``(batch, max_num_points, num_features)``,
          zero-padded at the end of the point axis.
        - ``labels`` is an int64 tensor ``(batch,)``, or ``None`` when every
          label is ``None``. Stacking ``None`` labels would otherwise fail in
          ``torch.tensor``.
    """
    points, labels = zip(*batch)
    max_length = max(len(p) for p in points)
    # Pad only along the point axis; the feature axis is left untouched.
    padded = np.stack([np.pad(p, ((0, max_length - len(p)), (0, 0))) for p in points])
    data = torch.as_tensor(padded, dtype=torch.float32)
    if all(label is None for label in labels):
        return data, None
    return data, torch.tensor(np.stack(labels)).to(torch.int64)

# Shuffled mini-batches of 32; collate_fn zero-pads each batch to its longest cloud.
# NOTE(review): if len(train_dataset) % 32 == 1, the final batch holds a single
# sample, which can make BatchNorm1d raise in training mode. Consider
# drop_last=True if that happens.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

# Training

In [14]:
# Training hyperparameters, named so they are easy to find and tune.
NUM_FEATURES = 4        # per-point features: [phi, theta, depth, v]
NUM_CLASSES = 2
DROPOUT = 0.3
LEARNING_RATE = 1e-3
LR_STEP_EPOCHS = 20     # StepLR: multiply the LR by LR_GAMMA every 20 epochs
LR_GAMMA = 0.5

model = PointNet(input_size=NUM_FEATURES, num_classes=NUM_CLASSES, dropout=DROPOUT).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=LR_STEP_EPOCHS,
    gamma=LR_GAMMA,
)
# Classification criterion; the name `loss` is kept because the training loop calls it.
loss = nn.CrossEntropyLoss()

In [15]:
from tqdm.auto import tqdm

NUM_EPOCHS = 50

for epoch in range(NUM_EPOCHS):
    model.train()  # dropout active, BatchNorm uses batch statistics
    running_loss, num_samples = 0.0, 0
    for data, labels in tqdm(train_loader, desc=f'epoch {epoch}'):
        data, labels = data.to(device), labels.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss_val = loss(output, labels)
        loss_val.backward()
        optimizer.step()
        # Weight by batch size so the epoch mean stays exact with a short final batch.
        batch_size = labels.size(0)
        running_loss += loss_val.item() * batch_size
        num_samples += batch_size
    scheduler.step()
    # Report the mean loss over the whole epoch, not just the noisy last batch.
    print(f'Epoch: {epoch}, Loss: {running_loss / max(num_samples, 1)}')

 18%|█▊        | 29/157 [00:20<01:28,  1.45it/s]


KeyboardInterrupt: 