### T-Net investigation
PointNet contains a T-Net sub-network that predicts a 3x3 transformation matrix,
which is applied to the input points. The hypothesis is that the T-Net learns to
rotate arbitrarily oriented point clouds back into a canonical orientation, so that
the downstream network is orientation-independent.

In [None]:
from pathlib import Path

import torch
import matplotlib.pyplot as plt
from torchvision.datasets import MNIST

from pointnet.data import MNIST3DDataset, RandomRotationTransform
from pointnet.model import ClassificationPointNet

In [None]:
# Dataset settings
# Root directory handed to the torchvision MNIST loader below (which is called
# with download=True). Path("") resolves to the current working directory.
DATASET_PATH = Path("")

# Model settings
# Device the model is moved to; it is also passed to MNIST3DDataset.
DEVICE = torch.device("cpu")
# Location passed to ClassificationPointNet.load.
# NOTE(review): presumably this must point at a saved model before the
# notebook is run -- confirm what ClassificationPointNet.load expects.
MODEL_PATH = Path("")

In [None]:
# Restore the trained classifier from MODEL_PATH and move it onto DEVICE.
model = ClassificationPointNet.load(MODEL_PATH).to(device=DEVICE)
# Switch to evaluation mode; kept as the last expression so the notebook
# still displays the returned value.
model.eval()

In [None]:
# Build the validation data: the MNIST test split (train=False), downloaded
# on demand, wrapped into the project's 3D point-cloud dataset.
valid_mnist = MNIST(root=DATASET_PATH, train=False, download=True)

# Transforms handed to the dataset; presumably each sample gets a random
# rotation -- confirm against RandomRotationTransform.
rotation_transforms = [RandomRotationTransform()]

# Positional arguments are kept in the original order because the parameter
# names of MNIST3DDataset are not visible from this notebook.
valid_dataset = MNIST3DDataset(
    valid_mnist, model.num_points, DEVICE, model.dtype, rotation_transforms
)

In [None]:
%matplotlib widget
number = 5
num_samples = 6

fig, axes = plt.subplots(
    num_samples, 2, subplot_kw={"projection": "3d"}, figsize=(6, 12)
)
axes[1, 0]
idx = 0
for row in range(num_samples):
    points, label = valid_dataset[idx]
    while label.item() != number:
        idx += 1
        points, label = valid_dataset[idx]
    points = points.unsqueeze(0)
    idx += 1

    axes[row, 0].scatter(points[0, 1, :], points[0, 0, :], points[0, 2, :])
    axes[row, 0].set_title(f"Sample: {idx - 1}")
    axes[row, 0].set_xlabel("x")
    axes[row, 0].set_ylabel("y")
    axes[row, 0].set_zlabel("z")
    axes[row, 0].set_xlim(-12, 12)
    axes[row, 0].set_ylim(-12, 12)
    axes[row, 0].set_zlim(-12, 12)

    with torch.no_grad():
        input_transform = model.backbone.input_transform.forward(points)
        trans_points = torch.bmm(input_transform, points)

    axes[row, 1].scatter(
        trans_points[0, 1, :], trans_points[0, 0, :], trans_points[0, 2, :]
    )
    axes[row, 1].set_title(f"Sample: {idx - 1}")
    axes[row, 1].set_xlabel("x")
    axes[row, 1].set_ylabel("y")
    axes[row, 1].set_zlabel("z")