## ModelNet40 dataset
Downloads and prepares the ModelNet40 dataset.

In [None]:
from collections import defaultdict
from pathlib import Path

import kagglehub
import matplotlib.pyplot as plt
import numpy as np
from pydantic import BaseModel

from pointnet.structs import Split

In [None]:
# Fetch the newest version of the ModelNet40 Kaggle dataset. kagglehub
# returns the local directory as a string; wrap it in a Path so later
# cells can build file paths with the `/` operator.
path = kagglehub.dataset_download(
    "balraj98/modelnet40-princeton-3d-object-dataset"
)
dataset_dir = Path(path)

### Read in the dataset CSV

In [None]:
class DataSample(BaseModel):
    """One row of the dataset metadata CSV, validated by pydantic."""

    # Identifier taken from the first CSV column.
    object_id: str
    # Object category label (second CSV column).
    class_name: str
    # Which dataset split this sample belongs to.
    split: Split
    # Location of the sample's OFF mesh file on disk.
    object_path: Path

In [None]:
import csv

# Locate the metadata CSV shipped with the dataset. Sort the matches so the
# choice is deterministic, and fail with a clear message if none exists.
csv_candidates = sorted(dataset_dir.glob("*.csv"))
if not csv_candidates:
    raise FileNotFoundError(f"No dataset CSV found in {dataset_dir}")
dataset_csv = csv_candidates[0]

# Group samples by split. Each data row is expected to hold
# (object_id, class_name, split, object path relative to ModelNet40/).
samples: dict[Split, list[DataSample]] = defaultdict(list)
with open(dataset_csv, "r", newline="", encoding="utf-8") as fp:
    # csv.reader copes with "\r\n" line endings and quoted fields; a plain
    # split("\n") would leave a trailing "\r" on the path column.
    reader = csv.reader(fp)
    next(reader, None)  # skip the header row

    for elements in reader:
        # Skip blank or malformed rows (e.g. a trailing empty line).
        if len(elements) != 4:
            continue

        split = Split(elements[2])
        samples[split].append(
            DataSample(
                object_id=elements[0],
                class_name=elements[1],
                split=split,
                object_path=dataset_dir / "ModelNet40" / elements[3],
            )
        )

### Read in an OFF file

In [None]:
# Load the vertices of one (arbitrarily chosen) training sample.
sample = samples[Split.TRAIN][1]
with open(sample.object_path, "r", encoding="utf-8") as fp:
    # The header must start with "OFF". Some ModelNet40 files fuse the
    # counts onto the header line (e.g. "OFF490 518 0"), so accept both
    # "OFF" alone and "OFF<counts>".
    first_line = fp.readline().strip()
    if not first_line.startswith("OFF"):
        raise RuntimeError("Invalid first line in OFF file.")

    # Counts line: "<num_verts> <num_faces> <num_edges>". It is either the
    # remainder of the header line or the next line of the file.
    counts_line = first_line[len("OFF"):].strip() or fp.readline().strip()
    # split() without an argument tolerates repeated spaces/tabs.
    num_verts, _, _ = counts_line.split()
    num_verts = int(num_verts)

    # One vertex per column: points[:, i] holds the three coordinates
    # read from the i-th vertex line.
    points = np.zeros((3, num_verts), dtype=np.float32)
    for idx in range(num_verts):
        points[:, idx] = [float(val) for val in fp.readline().split()]

In [None]:
%matplotlib widget

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.scatter(points[1, :], points[0, :], points[2, :])  # type: ignore
ax.set_title(f"Number: {sample.class_name}")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")

In [None]:
import torch

# Scratch check: gathering columns of a (3, N) tensor with indices drawn
# uniformly with replacement can yield more columns than N (useful for
# resampling a point cloud to a fixed size). Displays torch.Size([3, 300]).
x = torch.rand((3, 200))

indices = torch.randint(low=0, high=x.shape[1], size=(300,))
x[:, indices].shape

In [None]:
# Scratch check: sorted() turns an unordered set into an alphabetical list.
x = {"goose", "hello", "now", "what", "you"}
sorted(x)