Update example to match ModelNet class update (#264)
Signed-off-by: Jean-Francois Lafleche <jlafleche@nvidia.com>
Jean-Francois-Lafleche committed May 31, 2020
1 parent 594a596 commit 4e6e695
Showing 5 changed files with 782 additions and 110 deletions.
149 changes: 108 additions & 41 deletions examples/Classification/README.md
@@ -1,4 +1,4 @@
# Train a Pointcloud classifier in 5 lines of code
# Train a Pointcloud classifier

> **Skill level:** _Beginner_
@@ -41,8 +41,9 @@ from torch.utils.data import DataLoader
import kaolin as kal
from kaolin import ClassificationEngine
from kaolin.datasets import ModelNet
from kaolin.models.PointNet import PointNetClassifier as PointNet
from kaolin.models.PointNet import PointNetClassifier
import kaolin.transforms as tfs
from tqdm import tqdm
```

## Dataloading
@@ -51,75 +52,141 @@ Kaolin provides convenience functions to load popular 3D datasets (of course, Mo

To start, we will define a few important parameters:
```python
# Data Loading parameters
modelnet_path = 'path/to/ModelNet/'
categories = ['chair', 'sofa']
num_points = 1024
device = 'cuda'
workers = 8

# Training parameters
batch_size = 12
learning_rate = 1e-3
epochs = 10
```

The `modelnet_path` variable will hold the path to the ModelNet10 dataset. We will use the `categories` variable to specify which classes we want to learn to classify. `num_points` is the number of points we will sample from the mesh when transforming it to a pointcloud. Finally, we will use `device = 'cuda'` to tell PyTorch to run everything on the GPU.
The `modelnet_path` variable will hold the path to the ModelNet10 dataset. We will use the `categories` variable to specify which classes we want to learn to classify. `num_points` is the number of points we will sample from the mesh when transforming it to a pointcloud. Finally, when `device` is `'cuda'` our transform operations run on the GPU, so in that case we disable `DataLoader` multiprocessing and memory pinning.

```python
def to_device(inp):
    inp.to(device)
    return inp

transform = tfs.Compose([
    to_device,
    tfs.TriangleMeshToPointCloud(num_samples=num_points),
    tfs.NormalizePointCloud()
])

num_workers = 0 if device == 'cuda' else workers
pin_memory = device != 'cuda'
```



This code defines a `transform` that converts a mesh representation to a pointcloud and then _normalizes_ it so that it is centered at the origin and has a standard deviation of 1. Much like images, 3D data such as pointclouds need to be normalized for better classification performance.
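To make the normalization step concrete, here is a minimal, dependency-free sketch of centering a pointcloud and scaling it to unit standard deviation. It is not Kaolin's implementation: the helper name `normalize_point_cloud`, the per-axis scaling, the use of the population standard deviation, and the `eps` guard are all assumptions made for illustration, and `tfs.NormalizePointCloud` may differ in these details.

```python
import math


def normalize_point_cloud(points, eps=1e-8):
    """Center points at the origin and scale each axis to unit standard deviation.

    `points` is a list of equal-length coordinate tuples, e.g. (x, y, z).
    Hypothetical helper for illustration only; not part of Kaolin.
    """
    n = len(points)
    dims = len(points[0])
    # Per-axis mean of the cloud
    means = [sum(p[d] for p in points) / n for d in range(dims)]
    # Per-axis population standard deviation (an assumption of this sketch)
    stds = [
        math.sqrt(sum((p[d] - means[d]) ** 2 for p in points) / n)
        for d in range(dims)
    ]
    # eps avoids division by zero when all points share a coordinate
    return [
        tuple((p[d] - means[d]) / (stds[d] + eps) for d in range(dims))
        for p in points
    ]


cloud = [(0.0, 0.0, 0.0), (2.0, 4.0, 6.0), (4.0, 8.0, 12.0)]
normalized = normalize_point_cloud(cloud)
```

After this step every axis of `normalized` has a mean of approximately 0 and a standard deviation of approximately 1, regardless of the scale of the original mesh.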

```python
train_loader = DataLoader(ModelNet(modelnet_path, categories=categories,
                                   split='train', transform=transform, device=device),
                          batch_size=12, shuffle=True)
                                   split='train', transform=transform),
                          batch_size=batch_size, shuffle=True,
                          num_workers=num_workers, pin_memory=pin_memory)
```

Phew, that was slightly long! But here's what it does. It creates a `DataLoader` object for the `ModelNet10` dataset. In particular, we are interested in loading only the `chair` and `sofa` categories. The `split='train'` argument indicates that we're loading the 'train' split. The `transform=transform` argument applies the transform defined above, which converts each mesh into a pointcloud and normalizes it. The other parameters are fairly easy to decipher.

Similarly, the test dataset can be loaded up as follows.
Similarly, the validation dataset can be loaded up as follows.

```python
val_loader = DataLoader(ModelNet(modelnet_path, categories=categories,
                                 split='test', transform=transform, device=device),
                        batch_size=12)
                                 split='test', transform=transform),
                        batch_size=batch_size, num_workers=num_workers,
                        pin_memory=pin_memory)
```

## Setting up our model, optimizer and loss criterion
```python
model = PointNetClassifier(num_classes=len(categories)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = torch.nn.CrossEntropyLoss()
```
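`torch.nn.CrossEntropyLoss` expects one row of raw class scores per sample together with a flat vector of integer class indices, and by default it averages the loss over the batch. This is why the training loop reshapes the labels with `category.view(-1)`. The following plain-Python sketch shows the quantity being computed, assuming the model outputs raw (unnormalized) scores; the function name `cross_entropy` and the toy inputs are illustrative only.

```python
import math


def cross_entropy(logits, targets):
    """Mean cross-entropy for rows of raw class scores and integer class indices."""
    total = 0.0
    for row, target in zip(logits, targets):
        # Log-sum-exp with the row maximum subtracted for numerical stability
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        # Negative log-probability of the correct class
        total += log_sum_exp - row[target]
    return total / len(targets)


# Two samples, two classes (e.g. 'chair' = 0, 'sofa' = 1)
logits = [[2.0, 0.0], [0.0, 0.0]]
targets = [0, 1]
loss = cross_entropy(logits, targets)
```

The first sample is confidently correct and contributes a small loss, while the second is undecided between the two classes and contributes `log(2)`.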

## Training the pointcloud classifier

Now that all of the data is ready, we can train our classifier using the `ClassificationEngine` class provided by Kaolin. The following line of code will train and validate a _PointNet_ classifier, which is probably the simplest of pointcloud neural architectures.

```python
engine = ClassificationEngine(PointNet(num_classes=len(categories)),
                              train_loader, val_loader, device=device)
engine.fit()
for e in range(epochs):
    print(f'{"":-<10}\nEpoch: {e}\n{"":-<10}')

    train_loss = 0.
    train_accuracy = 0.

    model.train()
    for batch_idx, (data, attributes) in enumerate(tqdm(train_loader)):
        category = attributes['category'].to(device)
        pred = model(data)
        loss = criterion(pred, category.view(-1))
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Compute accuracy
        pred_label = torch.argmax(pred, dim=1)
        train_accuracy += torch.mean((pred_label == category.view(-1)).float()).item()

    print('Train loss:', train_loss / len(train_loader))
    print('Train accuracy:', train_accuracy / len(train_loader))

    val_loss = 0.
    val_accuracy = 0.

    model.eval()
    with torch.no_grad():
        for batch_idx, (data, attributes) in enumerate(tqdm(val_loader)):
            category = attributes['category'].to(device)
            pred = model(data)
            loss = criterion(pred, category.view(-1))
            val_loss += loss.item()

            # Compute accuracy
            pred_label = torch.argmax(pred, dim=1)
            val_accuracy += torch.mean((pred_label == category.view(-1)).float()).item()

    print('Val loss:', val_loss / len(val_loader))
    print('Val accuracy:', val_accuracy / len(val_loader))
```

This should display a long trail of training/validation stats that go like this:
This should display the training and validation loss and accuracy with each epoch:
```
Epoch: 0, Train loss: 0.6302577257156372, Train accuracy: 0.6666666865348816
Epoch: 0, Train loss: 0.608104020357132, Train accuracy: 0.7083333432674408
Epoch: 0, Train loss: 0.5694317619005839, Train accuracy: 0.7222222288449606
Epoch: 0, Train loss: 0.5308908596634865, Train accuracy: 0.7708333432674408
Epoch: 0, Train loss: 0.49486334919929503, Train accuracy: 0.8166666746139526
Epoch: 0, Train loss: 0.46080070237318677, Train accuracy: 0.8472222288449606
Epoch: 0, Train loss: 0.42722116623606, Train accuracy: 0.8690476247242519
Epoch: 0, Train loss: 0.3970450200140476, Train accuracy: 0.8854166716337204
Epoch: 0, Train loss: 0.36996302836471134, Train accuracy: 0.898148152563307
Epoch: 0, Train loss: 0.3460669249296188, Train accuracy: 0.9083333373069763
Epoch: 0, Train loss: 0.3246546902439811, Train accuracy: 0.9166666702790693
----------
Epoch: 0
----------
Train loss: 0.043646991286446335
Train accuracy: 0.9847328271574647
Val loss: 0.007488385620544089
Val accuracy: 1.0
----------
Epoch: 1
----------
Train loss: 0.044314810665162595
Train accuracy: 0.9872773567228826
Val loss: 0.003499787227209548
Val accuracy: 1.0
...
...
...
Epoch: 9, Val loss: 0.001074398518653652, Val accuracy: 1.0
Epoch: 9, Val loss: 0.0009598819953882614, Val accuracy: 1.0
Epoch: 9, Val loss: 0.0010726014385909366, Val accuracy: 1.0
Epoch: 9, Val loss: 0.0009777292708267023, Val accuracy: 1.0
Epoch: 9, Val loss: 0.0009104261476598671, Val accuracy: 1.0
Epoch: 9, Val loss: 0.0008428172893847938, Val accuracy: 1.0
Epoch: 9, Val loss: 0.0007834221362697592, Val accuracy: 1.0
Epoch: 9, Val loss: 0.0007336708978982643, Val accuracy: 1.0
Epoch: 9, Val loss: 0.0006904241699885461, Val accuracy: 1.0
Epoch: 9, Val loss: 0.0006549106868025025, Val accuracy: 1.0
----------
Epoch: 9
----------
Train loss: 0.017674992837741408
Train accuracy: 0.994910942689153
Val loss: 0.0015444405177517824
Val accuracy: 1.0
```
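Note that the loop reports the mean of the per-batch accuracies. If the last batch is smaller than the others, this can differ slightly from the fraction of correctly classified samples. The sketch below contrasts the two on toy data; the helper names and the `(prediction, label)` pair layout are assumptions made for illustration.

```python
def mean_of_batch_accuracies(batches):
    """Average the accuracy of each batch, as the training loop does."""
    per_batch = [sum(p == t for p, t in batch) / len(batch) for batch in batches]
    return sum(per_batch) / len(batches)


def sample_weighted_accuracy(batches):
    """Fraction of all samples that were classified correctly."""
    correct = sum(sum(p == t for p, t in batch) for batch in batches)
    total = sum(len(batch) for batch in batches)
    return correct / total


# Each batch is a list of (predicted_label, true_label) pairs.
batches = [
    [(0, 0), (1, 1), (0, 0), (1, 1)],  # full batch, all correct
    [(0, 1)],                          # smaller final batch, one mistake
]
batch_mean = mean_of_batch_accuracies(batches)
weighted = sample_weighted_accuracy(batches)
```

Here `batch_mean` is 0.5 while `weighted` is 0.8, because the single-sample batch counts as much as the full one in the first measure. To report the exact accuracy, accumulate the number of correct predictions and the number of samples instead of per-batch means.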

That's it, you've trained your first 3D classifier on pointcloud data using Kaolin!! Read through to find out more bells-and-whistles about the `ClassificationEngine` and how you can configure it.
@@ -134,11 +201,13 @@ We will create a new dataloader which will load the same data as our previous val

```python
test_loader = DataLoader(ModelNet(modelnet_path, categories=categories,
                                  split='test', transform=transform, device=device),
                                  split='test', transform=transform),
                         shuffle=True, batch_size=15)

test_batch, labels = next(iter(test_loader))
preds = engine.model(test_batch)
data, attr = next(iter(test_loader))
data = data.to(device)
labels = attr['category'].to(device)
preds = model(data)
pred_labels = torch.max(preds, axis=1)[1]
```

@@ -152,18 +221,16 @@ visualize_batch(test_batch, pred_labels, labels, categories)
<p align="center">
<img src="../../assets/classification_vis.png">
</p>
Looks like everything is green!
Looks like we have a working classifier!

## Bells and whistles

The `ClassificationEngine` can be customized to suit your needs.

You can train on other categories by simply changing the `categories` argument passed to the `ModelNet10` dataset object. For example, you can add a `bed` class by running
```python
dataset = ModelNet('/path/to/ModelNet10', categories=['chair', 'sofa', 'bed'],
split='train', rep='pointcloud', transform=norm, device='cuda:0')
```

You can also configure the parameters of the `PointNet` to your liking. For a more detailed explanation, refer to the documentation of the `PointNetClassifier` class.
You can also configure the parameters of the `PointNetClassifier` to your liking. For a more detailed explanation, refer to the documentation of the `PointNetClassifier` class.

Further, you can pass several parameters that configure the learning rate, optimizer, training duration, and more. A detailed description can be accessed from the documentation for the `ClassificationEngine` class.
Furthermore, you can experiment with different configurations of learning rate, optimizer, training duration, and more.
