In [None]:
# Notebook extensions and environment setup; kept in the first cell so it runs before torch is imported.
# allows showing the tensorboard widget (used by the `%tensorboard` cell)
%load_ext tensorboard
# reload imports on rerun (activated by `%autoreload 2`), so edits to local modules such as `data`
# and `point_clouds` are picked up without restarting the kernel
%load_ext autoreload

# set to 1 for cuda debugging (kernel launches become synchronous, so errors surface at the failing op)
# NOTE: only takes effect if set before CUDA is initialised; after the GPU has been used, restart the kernel.
%set_env CUDA_LAUNCH_BLOCKING=0

In [None]:
# show the tensorboard widget
# `lightning_logs` is presumably where the Trainer's default logger writes (no logger is configured here) — TODO confirm
%tensorboard --logdir lightning_logs

In [None]:
# enable import reloads
%autoreload 2

import pytorch_lightning as lightning

import torch
from torch.utils.data import random_split
from torchvision.datasets.utils import download_and_extract_archive

import numpy as np
import matplotlib.pyplot as plt

import pathlib
from data import MeshDataset, SingleTensorDataset
from point_clouds import PointCloudsModule

In [None]:
# Global matplotlib defaults: high-DPI figures with small legend text.
plt.rcParams.update({"figure.dpi": 250, "legend.fontsize": 6})

# Make sure the `plots` output directory exists (no error if it already does).
plots = pathlib.Path("plots")
plots.mkdir(exist_ok=True, parents=True)

In [None]:
# Hardware settings passed to `lightning.Trainer`. Fall back to the CPU when no CUDA device is
# visible so "Restart & Run All" still completes (slowly) on machines without a GPU.
accelerator = "gpu" if torch.cuda.is_available() else "cpu"
devices = 1

In [None]:
# Turn autograd off globally: sampling and plotting below need no gradients. The training cell
# re-enables it locally with `torch.autograd.enable_grad()`. The trailing `;` hides the object repr.
torch.autograd.set_grad_enabled(False);

In [None]:
# Local directory that holds (and receives) the extracted dataset.
data_root = pathlib.Path("data")
# Remote ModelNet10 archive and the file name it is saved under inside `data_root`.
url = "http://3dvision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip"
filename = pathlib.Path("ModelNet10.zip")

In [None]:
# Download and unpack ModelNet10 only when the extracted folder is missing. A fresh checkout then
# works under "Restart & Run All", and re-runs do not fetch the archive again.
# The dataset cell below reads from `data_root / "ModelNet10"`.
if not (data_root / "ModelNet10").exists():
    download_and_extract_archive(
        url=url,
        download_root=str(data_root),
        extract_root=str(data_root),
        filename=str(filename),
    )

In [None]:
# Points per mesh: passed to MeshDataset as `samples` and to the model as `input_points`.
samples_per_mesh = 2**11

In [None]:
# Chair meshes from ModelNet10. Each item is indexed as `points[:, k]` later on; presumably a
# (samples_per_mesh, 3) tensor — TODO confirm against `MeshDataset`.
dataset_root = data_root / "ModelNet10"

full_train_data = MeshDataset(root=dataset_root, split="train", shapes=["chair"], samples=samples_per_mesh)
n_train = int(0.8 * len(full_train_data))
n_val = len(full_train_data) - n_train

# Hold out 20% of the training meshes for validation. `val_data` is required by `PointCloudsModule`,
# and without this split it would be undefined on a fresh kernel. A fixed generator keeps the split
# identical across kernel restarts.
train_data, val_data = random_split(
    full_train_data,
    [n_train, n_val],
    generator=torch.Generator().manual_seed(0),
)

test_data = MeshDataset(root=dataset_root, split="test", shapes=["chair"], samples=samples_per_mesh)

In [None]:
# normalize within shape, based on train data
# train_chair = train_data[0].unsqueeze(0)
# val_chair = train_data[1].unsqueeze(0)

# mean = torch.mean(train_chair, dim=1)
# std = torch.std(train_chair, dim=1)

# train_chair = (train_chair - mean) / std
# val_chair = (val_chair - mean) / std

# train_data = SingleTensorDataset(train_chair)
# val_data = SingleTensorDataset(val_chair)

# len(train_data), len(val_data), len(test_data)

In [None]:
# Sanity check: scatter the points of the first training mesh. The style matches the sample plots
# further down (`lw=0`).
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
points = train_data[0]
ax.scatter(points[:, 0], points[:, 1], points[:, 2], s=1, color="black", alpha=0.5, lw=0)
ax.set_title("Training sample")
# ax.set_axis_off()
plt.show()

In [None]:
# Hyperparameters forwarded to `PointCloudsModule` as keyword arguments.
hparams = {
    "input_dim": 3,
    "input_points": samples_per_mesh,
    "condition_dim": 64,
    "batch_size": 1,
    "sample_size": 1,  # TODO
    "optimizer": "adam",
    "learning_rate": 1e-3,
    "weight_decay": 1e-5,
    "encoder_widths": [64, 128, 256],
    "rectifier_widths": [128, 128, 128],
    "activation": "selu",
    "integrator": "rk45",
    "beta": 0.5,
}

In [None]:
# Build the Lightning module from the train/val/test splits and the hyperparameter dict.
model = PointCloudsModule(
    train_data,
    val_data,
    test_data,
    **hparams,
)

In [None]:
# Seed Python/NumPy/torch RNGs so the training run itself is repeatable. Weights were already
# initialised when `model` was created, so this does not change the initial weights.
lightning.seed_everything(0)

trainer = lightning.Trainer(
    accelerator=accelerator,
    devices=devices,
    benchmark=True,  # sets torch.backends.cudnn.benchmark
    max_epochs=1000,
)

# Gradients were disabled globally earlier in the notebook, so re-enable them only for training.
with torch.autograd.enable_grad():
    model.train()
    trainer.fit(model)

# Switch to inference mode for sampling. The trailing `;` hides the (large) module repr.
model.eval();

# best_model = model.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
# best_model.eval()

# model = best_model

In [None]:
# Draw several shapes from the trained model and lay them out in a near-square grid.
# Assumes `samples` is indexed as (shape, point, coord) — TODO confirm against `model.sample`.
samples = model.sample(n_shapes=9, n_points=1024, steps=100)

n_samples = samples.shape[0]
cols = int(np.sqrt(n_samples))
rows = int(np.ceil(n_samples / cols))

# `figaspect` takes height / width. The grid is `rows` panels tall and `cols` panels wide, so use
# rows / cols (the previous cols / rows only worked when the grid was square).
fig = plt.figure(figsize=plt.figaspect(rows / cols))

for i, points in enumerate(samples):
    ax = fig.add_subplot(rows, cols, i + 1, projection="3d")
    ax.scatter(points[:, 0], points[:, 1], points[:, 2], s=1, color="black", alpha=0.5, lw=0)


In [None]:
%matplotlib widget
sample = model.sample(n_shapes=1, n_points=4096, steps=1000).squeeze()

fig = plt.figure()
ax = fig.add_subplot(projection="3d")

ax.scatter(sample[:, 0], sample[:, 1], sample[:, 2], s=1, color="black", alpha=0.5, lw=0)
plt.show()