# Point cloud classification

In this practical session, we will train our first neural networks for point cloud processing.
The task we target is point cloud classification. Analogous to image classification, the objective is to predict the class $c$ of a given point cloud $P$.

The notebook follows a classic setup for a classification pipeline:
1. notebook setup and data preparation
2. data augmentation creation
3. metrics
4. training and validation loop
5. network definition

### Notebook and data setup

We first import the libraries needed for the practical session.

In [None]:
from sklearn.metrics import confusion_matrix
from huggingface_hub import hf_hub_download
import plotly.graph_objects as go # for visualization
import tqdm
import os
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from scipy.spatial import KDTree
import matplotlib.pyplot as plt

Then, we download the [Modelnet40 dataset](https://modelnet.cs.princeton.edu/), which contains 9840 (resp. 2468) shapes for training (resp. validation). The shapes are divided into 40 categories.

In [None]:
hf_hub_download(repo_id="Msun/modelnet40", filename="modelnet40_ply_hdf5_2048.zip", repo_type="dataset", cache_dir=".")
!unzip ./datasets--Msun--modelnet40/snapshots/d5dc795541800feeb7a4b3bd3142729a0d2adf7a/modelnet40_ply_hdf5_2048

As usual, we also define a function to visualize point clouds with Plotly.

In [None]:
# display the point cloud
def point_cloud_visu(pts, cls=None):

    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=pts[:,0], y=pts[:,1], z=pts[:,2],
                mode='markers',
                marker=dict(size=3,
                            color=cls,
                            colorscale='Viridis',
                            )
            )
        ],
        layout=dict(
            scene=dict(
                xaxis=dict(visible=False),
                yaxis=dict(visible=False),
                zaxis=dict(visible=False),
                aspectmode="data", #this string can be 'data', 'cube', 'auto', 'manual'
                #a custom aspectratio is defined as follows:
                aspectratio=dict(x=1, y=1, z=0.95)
            )
        )
    )
    fig.show()

Finally, the last function of the notebook setup concerns data loading.
The files are in hdf5 format. They come in several parts, which are loaded separately and concatenated.

In [None]:
def get_data(rootdir, files):

  # read the list of hdf5 files of the split
  filenames = []
  with open(os.path.join(rootdir, files)) as file_list:
      for line in file_list:
          line = os.path.basename(line.strip())
          filenames.append(os.path.join(rootdir, line))

  # load each file fully in memory, closing it afterwards
  data = []
  labels = []
  for filename in filenames:
      with h5py.File(filename, "r") as f:
          data.append(f["data"][:])
          labels.append(f["label"][:])

  data = np.concatenate(data, axis=0)
  labels = np.concatenate(labels, axis=0)

  data = data[:,:,[0,2,1]] # for convenience we put the axis in the usual order

  return data, labels.ravel()


data, labels = get_data("modelnet40_ply_hdf5_2048", "train_files.txt")
data = torch.tensor(data, dtype=torch.float)
labels = torch.tensor(labels, dtype=torch.long)
print(f"Data (points): {data.shape} - Labels: {labels.shape}, num_labels {labels.max()+1}")
point_cloud_visu(data[0], data[0,:,2])

### Data transforms

We will first define the transformations to be applied to the point cloud.

#### Random decimation
In order to use smaller point clouds (easier with CPU notebooks), we will first implement random decimation of the point clouds.
The class randomly selects `num_points` points from the original point cloud.

The `__call__` method of the class takes as input a dictionary (`data_dict`) with two entries, "points" and "labels", and returns the same dictionary with an updated "points" field.

**Question:** implement the `__call__` method of the class.

In [None]:
class RandomDecimation:
  def __init__(self, num_points) -> None:
     self.n_pts = num_points

  def __call__(self, data_dict):
    # fill here
    points = data_dict["points"]
    ids = torch.randperm(points.shape[0],dtype=torch.long)[:self.n_pts]
    data_dict["points"] = points[ids]
    return data_dict

points = data[0].clone()
label = labels[0]
data_dict = {"points":points, "labels":label}
transform = RandomDecimation(512)
points_t1 = transform(data_dict)["points"] + torch.tensor([2.,0,0])
points = torch.cat([points, points_t1], dim=0)
point_cloud_visu(points, points[:,2])

#### Random rotation

A classic augmentation when training with point clouds is to apply a random rotation to the point clouds.

The class `RandomRotationZ` applies a random rotation around the $z$-axis:

$$P_\text{rot} = P M$$

where $P$ is the $N \times 3$ matrix of points (one point per row) and $M$ is the $3 \times 3$ rotation matrix.

**Question:** implement the `__call__` method of the class.

In [None]:
class RandomRotationZ:
  def __call__(self, data_dict):
    # fill here
    points = data_dict["points"]
    theta = (torch.rand((1,)) * torch.pi * 2).item()
    rot = torch.tensor(
        [[ np.cos(theta), np.sin(theta), 0],
         [-np.sin(theta), np.cos(theta), 0],
         [             0,             0, 1]], dtype=torch.float)
    data_dict["points"] = points@rot
    return data_dict

points = data[0].clone()
label = labels[0]
data_dict = {"points":points, "labels":label}
transform = RandomRotationZ()
points_t1 = transform(data_dict)["points"] + torch.tensor([2,0,0])
points_t2 = transform(data_dict)["points"] + torch.tensor([4,0,0])
points = torch.cat([points, points_t1, points_t2], dim=0)
point_cloud_visu(points, points[:,2])
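
As a quick sanity check of the row-vector convention used here, the sketch below applies a rotation with a fixed angle of $\pi/2$ (an arbitrary choice made for this illustration, where the class draws the angle at random) and verifies that the $z$ coordinate and the norms are unchanged. The two test points are made up for the example.

```python
import math
import torch

# Rotation about the z-axis, written for row-vector points (P @ M).
theta = math.pi / 2  # fixed angle so the result is easy to check by hand
rot = torch.tensor(
    [[ math.cos(theta), math.sin(theta), 0.0],
     [-math.sin(theta), math.cos(theta), 0.0],
     [ 0.0,             0.0,             1.0]], dtype=torch.float)

points = torch.tensor([[1.0, 0.0, 0.5],
                       [0.0, 2.0, -1.0]])
rotated = points @ rot  # approximately [[0, 1, 0.5], [-2, 0, -1]]

# A rotation matrix is orthogonal: M @ M^T is the identity
is_orthogonal = torch.allclose(rot @ rot.T, torch.eye(3), atol=1e-6)
# z is untouched and the distance of each point to the origin is preserved
z_preserved = torch.allclose(rotated[:, 2], points[:, 2])
norm_preserved = torch.allclose(rotated.norm(dim=1), points.norm(dim=1))
print(rotated)
print(is_orthogonal, z_preserved, norm_preserved)
```

With this convention, a positive angle sends the $x$-axis towards the $y$-axis.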

#### Random scaling

As a second augmentation, we implement a random scaling, which multiplies the point cloud by a factor drawn uniformly in a range $[\text{scale}_\text{min}, \text{scale}_\text{max}]$.
$$P_\text{scaled} = \text{scale} \cdot P$$

**Question:** create the class `RandomScale`, which, as before, has an `__init__` for the initialization of the scale range and a `__call__` to operate on the data dictionary.

In [None]:
class RandomScale:
    def __init__(self, scale_min, scale_max):
      self.mini = scale_min
      self.maxi = scale_max

    def __call__(self, data_dict):
      points = data_dict["points"]
      scale_factor = torch.rand((1,)).item() * (self.maxi - self.mini) + self.mini
      data_dict["points"] = points * scale_factor
      return data_dict

points = data[0].clone()
label = labels[0]
data_dict = {"points":points, "labels":label}
transform = RandomScale(scale_min=0.1, scale_max=2.)
points_t1 = transform(data_dict)["points"] + torch.tensor([2.,0,0])
points_t2 = transform(data_dict)["points"] + torch.tensor([4.,0,0])
points = torch.cat([points, points_t1, points_t2], dim=0)
point_cloud_visu(points, points[:,2])

#### Composition

We then need to assemble the transforms that we defined into a transformation pipeline.
To do so, we create a class `Compose` (similar in effect to the torchvision one). This class takes a list of transforms as constructor argument, and its `__call__` method takes a data dictionary, iterates over the transforms and outputs the updated dictionary.

**Question:** create the class `Compose`

In [None]:
class Compose:
    def __init__(self, transform_list):
        self.t = transform_list

    def __call__(self, data_dict):
        for t in self.t:
            data_dict = t(data_dict)
        return data_dict

transform = Compose(
    [
        RandomDecimation(1024),
        RandomRotationZ(),
        RandomScale(0.9,1.1),
    ])

points = data[0].clone()
label = labels[0]
data_dict = {"points":points, "labels":label}
points_t1 = transform(data_dict)["points"] + torch.tensor([2.,0,0])
points_t2 = transform(data_dict)["points"] + torch.tensor([4.,0,0])
points = torch.cat([points, points_t1, points_t2], dim=0)
point_cloud_visu(points, points[:,2])

### Modelnet40 dataset

Now we create the `Modelnet40Dataset`, the class that will be used to load the data. It inherits from the default torch `Dataset` class.
It implements three methods:
* `__init__`, which takes as arguments the root directory of the data, the split file, and optional transformations. It loads the data;
* `__len__`, which returns the number of samples in the split;
* `__getitem__`, which takes an index as input and returns the corresponding data dictionary (after going through the transforms).

**Question:** implement the `Modelnet40Dataset` class.

In [None]:
# create the dataloader
class Modelnet40Dataset(Dataset):

    def __init__(self, rootdir, files, transforms=None) -> None:
        super().__init__()
        self.rootdir = rootdir
        self.files = files
        self.transforms = transforms
        self.data, self.labels = get_data(rootdir, files)
        self.data = torch.tensor(self.data, dtype=torch.float)
        self.labels = torch.tensor(self.labels, dtype=torch.long)

    def __len__(self):
      return self.data.shape[0]

    def __getitem__(self, idx):

      points = self.data[idx]
      label = self.labels[idx]
      data_dict = {"points": points, "labels": label}

      if self.transforms is not None:
        data_dict = self.transforms(data_dict)

      return data_dict

train_transforms = Compose(
    [
        RandomDecimation(1024),
        RandomRotationZ(),
        RandomScale(0.9,1.1),
    ])

train_dataset = Modelnet40Dataset(rootdir="modelnet40_ply_hdf5_2048", files="train_files.txt", transforms=train_transforms)
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=1)

# use a dedicated loop variable so that the global `data` tensor is not overwritten
for batch in train_dataloader:
  print(batch["points"].shape, batch["labels"].shape)
  break
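
To make the batched shapes printed by the dataloader concrete, here is a minimal sketch with a toy dataset. `ToyCloudDataset` is a hypothetical stand-in for `Modelnet40Dataset` (10 random clouds of 16 points); it only illustrates how the default collation of `DataLoader` stacks the per-sample dictionaries into one dictionary of batched tensors.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyCloudDataset(Dataset):
    """Hypothetical stand-in for Modelnet40Dataset: 10 random clouds of 16 points."""

    def __init__(self):
        self.data = torch.rand(10, 16, 3)
        self.labels = torch.arange(10) % 4

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        # one sample: points of shape [16, 3] and a scalar label
        return {"points": self.data[idx], "labels": self.labels[idx]}

loader = DataLoader(ToyCloudDataset(), batch_size=4, shuffle=False, num_workers=0)
batch = next(iter(loader))
print(batch["points"].shape)  # torch.Size([4, 16, 3])
print(batch["labels"].shape)  # torch.Size([4])
```

Because the samples are stacked, every point cloud in a batch must have the same number of points, which is one practical reason to decimate all clouds to a fixed size.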


### Metrics

We need to quantitatively evaluate our model.
To do so, we will implement accuracy metrics, computed from the confusion matrix.

**Question:** fill the function that computes the confusion matrix. It takes as arguments the predictions (class scores of shape [B, num_classes]), a batch dictionary (labels are in "labels") and the number of classes.

In [None]:
def conf_matrix(predictions, batch, num_classes):
  pred_labels = torch.argmax(predictions, dim=-1).cpu().numpy()
  labels = batch["labels"].cpu().numpy()
  cm = confusion_matrix(labels, pred_labels, labels=np.arange(num_classes))
  return cm

predictions = torch.rand((100,6))
labels = torch.randint(0,6,(100,))
batch = {"labels": labels}
cm = conf_matrix(predictions, batch, 6)
plt.imshow(cm)

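**Question:** fill the `global_accuracy` function that computes the accuracy given a confusion matrix, that is:

$$A = \frac{\text{TP}}{N_\text{samples}}$$
where $\text{TP}$ is the number of correctly classified samples over all classes (the sum of the diagonal of the confusion matrix) and $N_\text{samples}$ is the total number of samples.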

In [None]:
def global_accuracy(cm):
  return np.diag(cm).sum() / cm.sum()

print(global_accuracy(cm))

**Question:** fill the `accuracy_per_class` function, that computes the accuracy per class (returns an array of accuracies).

In [None]:
def accuracy_per_class(cm):
  tp = np.diag(cm)
  tp_fn = cm.sum(axis=1)
  tp_fn[tp_fn==0] = 1
  return tp / tp_fn

print(accuracy_per_class(cm))
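
The difference between the global accuracy and the per-class accuracies is easiest to see on an imbalanced toy case. The confusion matrix below is made up for illustration (two classes, 100 and 10 samples); the formulas are the same as in the functions above.

```python
import numpy as np

# Rows are true classes, columns are predicted classes (scikit-learn convention).
cm = np.array([[90, 10],
               [ 5,  5]])

global_acc = np.diag(cm).sum() / cm.sum()      # 95 / 110
per_class_acc = np.diag(cm) / cm.sum(axis=1)   # [90 / 100, 5 / 10]
mean_class_acc = per_class_acc.mean()          # (0.9 + 0.5) / 2

print(f"global accuracy:     {global_acc:.3f}")
print(f"per-class accuracy:  {per_class_acc}")
print(f"mean class accuracy: {mean_class_acc:.3f}")
```

The global accuracy (about 0.86) is dominated by the large class, while the mean of the per-class accuracies (0.70) exposes the poorly recognized small class.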

Finally, we implement a function that aggregates the metrics.

**Question:** fill the `get_metrics` function. Argument is the confusion matrix, and the output is a dictionary containing:
* `confusion_matrix`: the confusion matrix
* `accuracy_per_class`: the accuracy per class (array)
* `average_accuracy`: the average accuracy per class
* `accuracy`: the global accuracy

In [None]:
def get_metrics(cm):
  acc = global_accuracy(cm)
  acc_class = accuracy_per_class(cm)
  macc = acc_class.mean()

  return {"accuracy": acc,
          "accuracy_per_class": acc_class,
          "average_accuracy": macc,
          "confusion_matrix": cm
          }

### Training and validation loops

We now enter the core of the training functions.

**Question:** implement the function `classif_loss`, which takes the predictions and a batch dictionary (labels are in "labels"), computes the cross entropy and returns the loss.

In [None]:
def classif_loss(predictions, batch):
  labels = batch["labels"]
  loss = torch.nn.functional.cross_entropy(predictions, labels)
  return loss
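
As a reminder of what `cross_entropy` expects, the sketch below feeds it raw scores (logits) and integer class ids, and compares the result with the explicit mean negative log-probability of the true class. The random logits and the four labels are arbitrary values chosen for the illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 40)            # raw, unnormalized scores for 40 classes
labels = torch.tensor([0, 3, 17, 39])  # one integer class id per sample

# cross_entropy applies the log-softmax internally
loss = F.cross_entropy(logits, labels)

# Same quantity written out: mean negative log-probability of the true class
log_probs = F.log_softmax(logits, dim=-1)
manual = -log_probs[torch.arange(4), labels].mean()

print(loss.item(), manual.item())
```

This is why the networks in this notebook end with a plain linear layer and no softmax.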

**Question:** fill the function `create_optimizer` that creates an AdamW optimizer for the network passed as argument. It returns the optimizer.

In [None]:
def create_optimizer(network):
  optimizer = torch.optim.AdamW(network.parameters(), lr=1e-3)
  return optimizer

We give a `to_device` function that moves the tensors of a batch (a tensor, a list or a dictionary) to a given device.

In [None]:
def to_device(batch, device):
  if isinstance(batch, torch.Tensor):
    return batch.to(device)
  elif isinstance(batch, list):
    batch_ = []
    for elem in batch:
      if isinstance(elem, torch.Tensor):
        elem = elem.to(device)
      batch_.append(elem)
    return batch_
  elif isinstance(batch, dict):
    batch_ = {}
    for key, elem in batch.items():
      if isinstance(elem, torch.Tensor):
        elem = elem.to(device)
      batch_[key] = elem
    return batch_
  else:
    raise ValueError("unknown batch type")

We now implement the training loop itself.
The training loop takes as arguments:
* a network
* a train dataloader
* an optimizer
* the number of classes
* a device

It does several things:
- it performs an optimization loop over the training dataloader;
- it computes the overall confusion matrix (iteratively);
- it prints the accuracy and average accuracy computed from the running confusion matrix.

It returns the metric dictionary.

**Question:** implement the training loop

*Note:* do not forget to put the network in training mode


In [None]:
def train_loop(network, train_dataloader, optimizer, num_classes, device):
  network.train()
  train_cm = np.zeros((num_classes, num_classes), dtype=np.int64)
  train_loss = 0
  iter_count = 0

  t = tqdm.tqdm(train_dataloader, ncols=100)
  for batch in t:
    batch = to_device(batch, device)
    predictions = network(batch)
    loss = classif_loss(predictions, batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    cm = conf_matrix(predictions, batch, num_classes)
    train_cm += cm
    train_loss += loss.item()
    iter_count += 1

    # batch metrics
    metrics = get_metrics(train_cm)
    t.set_description_str(f"Acc: {metrics['accuracy']*100:.1f}, mAcc: {metrics['average_accuracy']*100:.1f}")

  return metrics

The validation loop follows the same structure as the training loop, without the backward pass and the optimizer step.

**Question:** implement the validation loop

*Note:* do not forget to put the network in evaluation mode

*Note2:* use the inference mode (or no gradient) of pytorch to prevent allocating memory for the gradients.

In [None]:
def val_loop(network, test_dataloader, num_classes, device):
  network.eval()
  test_cm = np.zeros((num_classes, num_classes), dtype=np.int64)
  t = tqdm.tqdm(test_dataloader, ncols=100)

  with torch.inference_mode():
    for batch in t:
      batch = to_device(batch, device)
      predictions = network(batch)
      cm = conf_matrix(predictions, batch, num_classes)
      test_cm += cm

      # batch metrics
      metrics = get_metrics(test_cm)
      t.set_description_str(f"Acc: {metrics['accuracy']*100:.1f}, mAcc: {metrics['average_accuracy']*100:.1f}")

  return metrics


### PointNet architecture (or at least close to it)

[PointNet](https://arxiv.org/abs/1612.00593) is one of the first networks designed specifically for point cloud processing.
The design of PointNet is simple (made of MLPs), with the constraint that the output of the network has to remain invariant to permutations of the points (very important with point clouds).

To do so, it uses a pooling function as a permutation invariant aggregation.
In practice, they chose a max pooling over all the points.
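
The invariance can be checked numerically with a small sketch: random per-point features (made up for the illustration, with shape [B, C, N]) are shuffled along the point dimension, and the max pooling gives exactly the same result.

```python
import torch

torch.manual_seed(0)
features = torch.rand(2, 8, 100)           # [B, C, N] per-point features
perm = torch.randperm(features.shape[-1])  # a random reordering of the N points
shuffled = features[:, :, perm]            # same points, different order

pooled = features.max(dim=-1).values       # [B, C]
pooled_shuffled = shuffled.max(dim=-1).values

print(pooled.shape)
print(torch.equal(pooled, pooled_shuffled))  # True
```

Any symmetric reduction (sum, mean) would also be permutation invariant; the max is the choice stated above.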

We are interested in the classification version of the paper.

In this practical session, in order to stay lightweight and run on CPU, we simplify it.

[![](https://mermaid.ink/img/pako:eNplkEGLwjAQhf9KmFMFLdi9FVaorTe7lN3bGsGhSZtCmpSYoIv1v2_EqgFzmcz73oOZuUCtGYcUGqlPtUBjyfabKuJftqvsck8Wi9XIPj9GkiVRua1md7j2MNk_2Tpk-S6O4yCYh7DwQRXAIoRZMgExkjKicCjxTCqtZafau0Ocv24WX5cHCo9hwtg0w7tUvEvlS8lluIGcQD6STaSdHZydwRx6bnrsmD_X5WakYAXvOYXUfxlv0ElLgaqrt6Kz-udP1ZBa4_gcjHatgLRBefSdGxhaXnTYGuyf6oDqV-tHf_0HWD54Lw?type=png)](https://mermaid.live/edit#pako:eNplkEGLwjAQhf9KmFMFLdi9FVaorTe7lN3bGsGhSZtCmpSYoIv1v2_EqgFzmcz73oOZuUCtGYcUGqlPtUBjyfabKuJftqvsck8Wi9XIPj9GkiVRua1md7j2MNk_2Tpk-S6O4yCYh7DwQRXAIoRZMgExkjKicCjxTCqtZafau0Ocv24WX5cHCo9hwtg0w7tUvEvlS8lluIGcQD6STaSdHZydwRx6bnrsmD_X5WakYAXvOYXUfxlv0ElLgaqrt6Kz-udP1ZBa4_gcjHatgLRBefSdGxhaXnTYGuyf6oDqV-tHf_0HWD54Lw)

The MLPs have a fixed hidden size, with three layers, batchnorm and ReLU activations (except for the final classification layer).

**Question:** fill the PointNet class.

*Note:* for easy use of the batchnorm, we permute the dimensions of the points at the entry of the `forward` function.

In [None]:
class PointNet(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_classes):
        super().__init__()

        self.encoder = torch.nn.Sequential(
            torch.nn.Conv1d(in_channels, hidden_channels, 1, bias=False),
            torch.nn.BatchNorm1d(hidden_channels),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv1d(hidden_channels, hidden_channels, 1, bias=False),
            torch.nn.BatchNorm1d(hidden_channels),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv1d(hidden_channels, hidden_channels, 1, bias=False),
            torch.nn.BatchNorm1d(hidden_channels),
            torch.nn.ReLU(inplace=True),
        )

        self.class_head = torch.nn.Sequential(
            torch.nn.Linear(hidden_channels, hidden_channels, bias=False),
            torch.nn.BatchNorm1d(hidden_channels),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(hidden_channels, hidden_channels, bias=False),
            torch.nn.BatchNorm1d(hidden_channels),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(hidden_channels, out_classes, bias=False),
        )


    def forward(self, data_dict):
        points = data_dict["points"]
        x = points.permute(0,2,1)
        x = self.encoder(x)
        x = torch.max(x, dim=-1)[0]
        x = self.class_head(x)
        return x
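
The encoder uses `Conv1d` layers with kernel size 1 on [B, C, N] tensors. The sketch below illustrates, under the assumption of bias-free layers with arbitrary sizes (3 input and 8 output channels), that such a convolution is the same operation as a linear layer applied independently to every point.

```python
import torch

torch.manual_seed(0)
conv = torch.nn.Conv1d(3, 8, kernel_size=1, bias=False)
linear = torch.nn.Linear(3, 8, bias=False)

# Copy the convolution weights [8, 3, 1] into the linear layer [8, 3]
with torch.no_grad():
    linear.weight.copy_(conv.weight.squeeze(-1))

points = torch.rand(2, 100, 3)                # [B, N, 3]
out_conv = conv(points.permute(0, 2, 1))      # [B, 8, N]
out_linear = linear(points).permute(0, 2, 1)  # [B, N, 8] -> [B, 8, N]

print(torch.allclose(out_conv, out_linear, atol=1e-5))
```

The channel-first layout is convenient because `BatchNorm1d` normalizes over the second dimension of a [B, C, N] tensor.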


### Training loop

Finally, we train the network for classification.

For the training augmentations, we use a downsampling (to 1024 points), a random rotation and a random scaling.

For validation, we use only random downsampling and no further augmentations.


In [None]:
# number of classes in the datasets and device
num_classes = 40
device = torch.device("cpu")
num_epochs = 5 # will train for 5 epochs, that is low, but OK timewise for a practical session

# Create the network
# network = ...
network = PointNet(3,32,40)

# Create the optimizer
# optimizer = ...
optimizer = create_optimizer(network)

# Create the transforms for training
# train_transforms = ...
train_transforms = Compose(
    [
        RandomDecimation(1024),
        RandomRotationZ(),
        RandomScale(0.9,1.1),
    ])

# create the validation transforms
# val_transforms = ...
val_transforms = Compose(
    [
        RandomDecimation(1024),
    ]
    )

# create the train dataset and val dataset
train_dataset = Modelnet40Dataset(rootdir="modelnet40_ply_hdf5_2048", files="train_files.txt", transforms=train_transforms)
val_dataset = Modelnet40Dataset(rootdir="modelnet40_ply_hdf5_2048", files="test_files.txt", transforms=val_transforms)

# create the train and val dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=1)
val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=1)

# loop over the epochs
# and train / val the network
# ...
for epoch in range(num_epochs):
  train_loop(network, train_dataloader, optimizer, num_classes, device)
  metrics = val_loop(network, val_dataloader, num_classes, device)

plt.imshow(metrics["confusion_matrix"])

### Test time augmentations

Currently, we validate without augmentation, which makes the evaluation almost deterministic (the only stochasticity comes from the random downsampling).

The network has been trained with multiple augmentations. It is also possible to use these augmentations at test time and accumulate the predictions to improve the scores.

This is known as *test-time augmentation*.

**Important:** this multiple-prediction scheme usually improves the scores, but should not be used by default when comparing networks. If it is used, (1) it should be mentioned explicitly and (2) the score without augmentations should also be reported.


**Question:** Implement the `TTA` class, which takes a number of repetitions and a transform as arguments.
It applies the transform several times to the points (from the data dictionary) and stores the new point clouds in a list.
The stacked list replaces the "points" field in the data dictionary before being returned.


In [None]:
# implement a test loop with TTA
class TTA:

  def __init__(self, num_tta, transform):
    self.num_tta = num_tta
    self.transform = transform

  def __call__(self, data_dict):
    points = data_dict["points"]
    point_list = []
    for _ in range(self.num_tta):
      aug_data = self.transform({"points":points})
      point_list.append(aug_data["points"])
    points = torch.stack(point_list, dim=0)
    data_dict["points"] = points
    return data_dict

**Question:** implement the test loop. The predictions of the augmented versions of a single point cloud should be aggregated (here averaged, which gives the same predicted class as summing).

In [None]:
def test_loop(network, test_dataloader, num_classes, device):
  network.eval()
  test_cm = np.zeros((num_classes, num_classes), dtype=np.int64)
  t = tqdm.tqdm(test_dataloader, ncols=100)

  with torch.inference_mode():
    for batch in t:
      batch = to_device(batch, device)
      sh = batch["points"].shape
      batch["points"] = batch["points"].reshape(-1, sh[2], sh[3])
      predictions = network(batch)
      predictions = predictions.reshape(sh[0], sh[1], -1)
      predictions = predictions.mean(dim=1)
      cm = conf_matrix(predictions, batch, num_classes)
      test_cm += cm

      # batch metrics
      metrics = get_metrics(test_cm)
      t.set_description_str(f"Acc: {metrics['accuracy']*100:.1f}, mAcc: {metrics['average_accuracy']*100:.1f}")
  return metrics


# create the test time transforms
test_transforms = TTA(8,Compose(
    [
        RandomDecimation(1024),
        RandomRotationZ(),
        RandomScale(0.9,1.1)
    ]
    ))
test_dataset = Modelnet40Dataset(rootdir="modelnet40_ply_hdf5_2048", files="test_files.txt", transforms=test_transforms)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1)
metrics = test_loop(network, test_dataloader, num_classes, device)
plt.imshow(metrics["confusion_matrix"])
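
The reshaping done in the test loop can be illustrated on random scores. In this sketch, the sizes (2 shapes, 8 augmentations, 40 classes) and the random scores are assumptions for the example; it shows how the flat network output is regrouped per shape, and that averaging or summing over the augmentations leads to the same predicted class.

```python
import torch

torch.manual_seed(0)
batch_size, num_tta, num_classes = 2, 8, 40

# Scores as they come out of the network for a flattened [B * T, N, 3] batch
flat_scores = torch.randn(batch_size * num_tta, num_classes)

# Regroup the T augmented predictions of each shape: [B, T, num_classes]
scores = flat_scores.reshape(batch_size, num_tta, num_classes)

pred_from_mean = scores.mean(dim=1).argmax(dim=-1)  # [B]
pred_from_sum = scores.sum(dim=1).argmax(dim=-1)    # [B]

print(pred_from_mean, pred_from_sum)
```

The mean only rescales the sum by a positive constant, so the argmax is unchanged.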

## DGCNN: Dynamic Graph Convolutional Neural Network

[DGCNN](https://arxiv.org/abs/1801.07829) is one of the most widely used networks for geometric deep learning on graphs (or small point clouds).

It relies on the computation of a local graph (KNN graph).
We will compute the indices of the neighbors of each point.
To do so, we can use a transformation.

**Question:** fill the `KNNIndices` class, that computes the K-nearest neighbors of all the points in the point cloud using a KDTree. The indices are stored in the dictionary, in the field "indices".

In [None]:
class KNNIndices:
    def __init__(self, k):
      self.k = k

    def __call__(self, data_dict):
      points = data_dict["points"]
      tree = KDTree(points)
      _, indices = tree.query(points, self.k)
      data_dict["indices"] = torch.tensor(indices, dtype=torch.long)
      return data_dict


# take the first training shape directly from the dataset object
points = train_dataset.data[0].clone()
label = train_dataset.labels[0]
data_dict = {"points":points, "labels":label}
transform = KNNIndices(32)
data_dict = transform(data_dict)
colors = np.zeros(points.shape[0], dtype=np.uint8)
colors[data_dict["indices"][0]] = 1
colors[0] = 2
point_cloud_visu(points, colors)
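
One detail worth keeping in mind when querying a KDTree with the same points it was built from: each point is its own nearest neighbor. The sketch below checks this on random points (50 points and `k=4` are arbitrary values for the illustration).

```python
import numpy as np
from scipy.spatial import KDTree

rng = np.random.default_rng(0)
points = rng.random((50, 3))

tree = KDTree(points)
# neighbors are returned sorted by increasing distance
distances, indices = tree.query(points, k=4)

print(indices.shape)  # (50, 4)
# the first neighbor of every point is the point itself, at distance 0
print(np.array_equal(indices[:, 0], np.arange(50)))
print(np.allclose(distances[:, 0], 0.0))
```

With `KNNIndices(16)`, each neighborhood therefore contains the point itself and its 15 closest other points (assuming no duplicated points).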


**Question:** implement the function `knn_gather`. The function uses the `gather` function of torch to aggregate the features of the K-nearest neighbors of each point into a tensor.

Inputs:
* `x` of shape [B, C, N]
* `indices` of shape [B, N, K], to be reshaped to [B, 1, N*K] and then expanded to [B, C, N*K]

Outputs:
* `x_knn` of shape [B, C, N, K]

In [None]:
def knn_gather(x, indices):
    indices_ = indices.view(indices.shape[0], 1, -1).expand(-1, x.shape[1], -1) # [B C N*K]
    x_knn = torch.gather(x, dim=2, index=indices_)
    x_knn = x_knn.reshape(x_knn.shape[0], x_knn.shape[1], x.shape[2], -1)
    return x_knn # shape [B C N K]
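
To see what the gather computes, the sketch below repeats the same steps on small random tensors and compares them with plain indexing done one batch element at a time. The sizes and the random neighbor indices are made up for the check; the intended relation is `x_knn[b, c, n, k] = x[b, c, indices[b, n, k]]`.

```python
import torch

torch.manual_seed(0)
B, C, N, K = 2, 3, 5, 2
x = torch.rand(B, C, N)
indices = torch.randint(0, N, (B, N, K))

# Gather-based version, following the shapes described above
flat = indices.reshape(B, 1, N * K).expand(-1, C, -1)  # [B, C, N*K]
x_knn = torch.gather(x, dim=2, index=flat).reshape(B, C, N, K)

# Reference: advanced indexing, x[b] is [C, N] and indices[b] is [N, K]
reference = torch.stack([x[b][:, indices[b]] for b in range(B)], dim=0)

print(x_knn.shape)                     # torch.Size([2, 3, 5, 2])
print(torch.equal(x_knn, reference))   # True
```

The gather version avoids the Python loop over the batch, which matters once the batch size grows.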

**Question:** implement the function `features_from_graph` that computes the input features of the edge convolution from DGCNN.
The function:
1. computes the knn features ($x_j$)
2. computes the edge features ($x_j - x_i$)
3. concatenates the edge features with the node features ($[x_j - x_i, x_i]$)

Output shape is [B, 2*C, N, K]

In [None]:
def features_from_graph(x, indices):
    x_edge = knn_gather(x, indices) # [B, C, N, K]
    x_edge = x_edge - x.unsqueeze(-1)
    x = torch.cat([x_edge, x.unsqueeze(-1).repeat(1,1,1,x_edge.shape[-1])], dim=1)
    return x # shape [B, 2*C, N, K]
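
A tiny hand-checkable case may help to read the edge features. The sketch below follows the convention of the code of `features_from_graph` (neighbor minus center, then the center features), on a made-up graph of 3 points where each point has itself and the next point as neighbors.

```python
import torch

B, C, N, K = 1, 2, 3, 2
# channel 0 holds [0, 1, 2], channel 1 holds [3, 4, 5]
x = torch.arange(B * C * N, dtype=torch.float).reshape(B, C, N)
# first neighbor: the point itself, second neighbor: the next point (cyclic)
indices = torch.tensor([[[0, 1], [1, 2], [2, 0]]])

x_j = torch.stack([x[b][:, indices[b]] for b in range(B)])  # neighbors [B, C, N, K]
x_i = x.unsqueeze(-1).expand(-1, -1, -1, K)                 # centers   [B, C, N, K]
edge = torch.cat([x_j - x_i, x_i], dim=1)                   # [B, 2*C, N, K]

print(edge.shape)        # torch.Size([1, 4, 3, 2])
print(edge[0, 0, :, 0])  # self edges: tensor([0., 0., 0.])
print(edge[0, 0, :, 1])  # next-point edges: tensor([ 1.,  1., -2.])
```

The first C channels carry the local geometry (relative features), and the last C channels keep the absolute features of the center point.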


**Question:** fill the DGCNN convolutional module.
The forward pass:
1. computes the features from the graph
2. applies a linear layer (a Conv2d, with kernel size 1)
3. applies a batchnorm
4. uses ReLU as activation
5. computes the max over the K values

Output is of size [B, C, N]

In [None]:
class DGCNNConv(torch.nn.Module):

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv = torch.nn.Conv2d(2*in_channels, out_channels, 1, bias=False)
        self.bn = torch.nn.BatchNorm2d(out_channels)
        self.act = torch.nn.ReLU(inplace=True)

    def forward(self, x, indices):
        # compute the edge features
        x = features_from_graph(x, indices)
        x = self.act(self.bn(self.conv(x)))
        x = torch.max(x, dim=-1)[0]
        return x


**Question:** fill the DGCNN network and train the model. For simplicity, we only train for 5 epochs and use 2 convolution blocks + 1 linear layer for the classification head.

In [None]:
class DGCNN(torch.nn.Module):

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()

        self.conv1 = DGCNNConv(in_channels, hidden_channels)
        self.conv2 = DGCNNConv(hidden_channels, hidden_channels)
        self.classifier_head = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, data_dict):
        # compute the indices
        x = data_dict["points"].permute(0,2,1)
        indices = data_dict["indices"]

        # through the encoder
        x = self.conv1(x, indices)
        x = self.conv2(x, indices)

        # through the classifier
        x = torch.max(x, dim=-1)[0]
        x = self.classifier_head(x)

        return x

num_classes = 40
network = DGCNN(3,32,40)
optimizer = create_optimizer(network)
device = torch.device("cpu")

num_epochs = 5

train_transforms = Compose(
    [
        RandomDecimation(1024),
        RandomRotationZ(),
        RandomScale(0.9,1.1),
        KNNIndices(16),
    ])

val_transforms = Compose(
    [
        RandomDecimation(1024),
        KNNIndices(16),
    ]
    )

train_dataset = Modelnet40Dataset(rootdir="modelnet40_ply_hdf5_2048", files="train_files.txt", transforms=train_transforms)
val_dataset = Modelnet40Dataset(rootdir="modelnet40_ply_hdf5_2048", files="test_files.txt", transforms=val_transforms)
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=1)
val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=1)

for epoch in range(num_epochs):
  train_loop(network, train_dataloader, optimizer, num_classes, device)
  val_loop(network, val_dataloader, num_classes, device)