## MNIST 3D dataset

In [None]:
from pathlib import Path

import torch
import matplotlib.pyplot as plt
import numpy as np
from torchvision.datasets import MNIST

from pointnet.data import MNIST3DDataset

In [None]:
# Load the MNIST training split from a local directory, downloading it there
# first if it is not already present.
mnist_root = Path("/mnt/data/code/experiments/lidar/pointnet/data/")
dataset_2d = MNIST(root=mnist_root, train=True, download=True)

### Visualise the distribution of white pixels per image

In [None]:
# Count, for every training image, how many pixels are "white", i.e. strictly
# brighter than the mid-grey threshold below. Images are presumably 8-bit
# greyscale (0-255) -- TODO confirm if the dataset transform changes.
white_threshold = 127
num_white_pixels: list[int] = [
    # `.sum()` yields a numpy scalar; cast so the list really holds `int`s,
    # as the annotation promises.
    int((np.asarray(image) > white_threshold).sum())
    for image, _ in dataset_2d  # labels are not needed here
]

In [None]:
# Histogram of white-pixel counts across the training set, followed by the
# smallest and largest count observed.
plt.hist(num_white_pixels)
plt.title("Distribution of white pixels per image")
plt.xlabel("Number of white pixels")
plt.ylabel("Frequency")
plt.grid()

fewest_white, most_white = min(num_white_pixels), max(num_white_pixels)
print(f"Minimum number of white pixels: {fewest_white}")
print(f"Maximum number of white pixels: {most_white}")

### Visualise samples

In [None]:
# Dataset constants: the number of points requested from the 3D dataset
# (presumably per sample), and the torch device/dtype it should produce.
dtype = torch.float32
device = torch.device("cpu")
num_points = 200

# Wrap the 2D MNIST images in the 3D dataset; indexing it yields
# (points, label) pairs.
dataset = MNIST3DDataset(dataset_2d, num_points, device, dtype)

In [None]:
%matplotlib widget
points, label = dataset[2]

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.scatter(points[:, 1], points[:, 0], points[:, 2], cmap="viridis") # type: ignore
ax.set_title(f"Number: {label}")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")
ax.set_xlim(0, 24)
ax.set_ylim(0, 24)
ax.set_zlim(-12, 12)