In [1]:
# Load the TensorBoard extension so the `%tensorboard` widget can be shown in the notebook.
%load_ext tensorboard

# Load the autoreload extension to allow import reloading
# (it is switched on later with `%autoreload 2`).
%load_ext autoreload

# Set to 1 for CUDA debugging.
# NOTE(review): presumably this has to be set before CUDA is initialised to take effect - confirm.
%set_env CUDA_LAUNCH_BLOCKING=0

env: CUDA_LAUNCH_BLOCKING=0


In [2]:
# Show the TensorBoard widget, pointed at the `lightning_logs` log directory.
%tensorboard --logdir lightning_logs

In [3]:
%autoreload 2

import pytorch_lightning as lightning

import torch
from torch.utils.data import random_split
from torchvision.datasets.utils import download_and_extract_archive

import numpy as np
import matplotlib.pyplot as plt

import pathlib
from data import MeshDataset, SingleTensorDataset
from models import PointCloudsModule

import utils
import blender_plot as bp

In [4]:
# Matplotlib defaults, grouped by rc section.
rc_defaults = {
    "figure": {"dpi": 250, "titlesize": 6},
    "legend": {"fontsize": 6},
}
for rc_group, rc_settings in rc_defaults.items():
    plt.rc(rc_group, **rc_settings)

# Output locations for saved figures and saved sample tensors.
plots_path = pathlib.Path("plots")
samples_path = pathlib.Path("samples")

# Create both output directories up front; existing directories are left alone.
for output_dir in (plots_path, samples_path):
    output_dir.mkdir(parents=True, exist_ok=True)

In [5]:
# Turn gradient tracking off globally for this notebook; the training cell
# re-enables it locally with `torch.autograd.enable_grad()` around `trainer.fit`.
torch.autograd.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7efd517c66b0>

In [6]:
# Root directory of the dataset; the datasets below read from `data_root / "processed"`.
data_root = pathlib.Path("data")

In [7]:
# Passed to MeshDataset as `samples` and to the encoder hyperparameters as `points`.
samples_per_mesh = 2048

In [8]:
# Load the full "train" split; it is partitioned 80/20 into train and validation subsets below.
full_train_data = MeshDataset(root=data_root / "processed", split="train", shapes="all", samples=samples_per_mesh)
n_train = int(0.8 * len(full_train_data))
n_val = len(full_train_data) - n_train

# Use a dedicated generator with a fixed seed so the split is reproducible
# WITHOUT touching the global RNG. (Calling `torch.seed()` to "remember" the
# global seed actually reseeds the global RNG with a fresh non-deterministic
# value, which would discard any seed set earlier in the session.)
split_generator = torch.Generator().manual_seed(42)
train_data, val_data = random_split(full_train_data, [n_train, n_val], generator=split_generator)

test_data = MeshDataset(root=data_root / "processed", split="test", shapes="all", samples=samples_per_mesh)

# Display the sizes of the three splits.
len(train_data), len(val_data), len(test_data)

(3192, 799, 908)

In [9]:
# Sizes shared by the encoder and rectifier hyperparameters.
inputs = 3
conditions = 128

# Hyperparameters forwarded to the encoder.
encoder_hparams = {
    "inputs": inputs,
    "points": samples_per_mesh,
    "conditions": conditions,
    "kind": "deterministic",
    "dropout": 0.1,
    "widths": [[], []],  # reimplement, maybe
    "activation": "selu",
    "checkpoints": True,
    "Lambda": 0.5,
}

# Hyperparameters forwarded to the rectifier.
rectifier_hparams = {
    "inputs": inputs,
    "conditions": conditions,
    "dropout": 0.1,
    "widths": [256, 512, 512, 512, 256],
    "activation": "selu",
    "checkpoints": True,
    "integrator": "euler",
}

# Top-level hyperparameters; unpacked as keyword arguments when the model is built.
hparams = {
    "accelerator": "gpu",
    "devices": 1,
    "max_epochs": 5,
    "optimizer": "adam",
    "learning_rate": 1e-3,
    "weight_decay": 1e-5,
    "batch_size": 80,
    "accumulate_batches": None,
    "gradient_clip": 1.0,
    "encoder_hparams": encoder_hparams,
    "rectifier_hparams": rectifier_hparams,
    "augment_noise": 0.05,
    "mmd_scales": torch.logspace(-2, 2, 20),
    "mmd_samples": None,
    "time_samples": 2,
    "encoder_weight": 1.0,
    "rectifier_weight": 1.0,
    "profiler": "simple",
}

In [10]:
# Build the model from the three dataset splits and the hyperparameter dict,
# then print its module structure.
model = PointCloudsModule(train_data, val_data, test_data, **hparams)
print(model)

PointCloudsModule(
  (encoder): Encoder(
    (network): Sequential(
      (0): Linear(in_features=3, out_features=128, bias=True)
      (1): SELU()
      (2): GlobalMultimaxPool1d()
      (3): Dropout(p=0.1, inplace=False)
      (4): Linear(in_features=128, out_features=256, bias=True)
      (5): SELU()
      (6): Dropout(p=0.1, inplace=False)
      (7): Linear(in_features=256, out_features=512, bias=True)
      (8): SELU()
      (9): GlobalMultimaxPool1d()
      (10): Linear(in_features=512, out_features=512, bias=True)
      (11): SELU()
      (12): Dropout(p=0.1, inplace=False)
      (13): Linear(in_features=512, out_features=512, bias=True)
      (14): SELU()
      (15): GlobalMultimaxPool1d()
      (16): Linear(in_features=512, out_features=512, bias=True)
      (17): SELU()
      (18): Dropout(p=0.1, inplace=False)
      (19): Linear(in_features=512, out_features=512, bias=True)
      (20): SELU()
      (21): GlobalMultimaxPool1d()
      (22): Linear(in_features=512, out_features

In [None]:
# The trainer is built by the model itself.
trainer = model.configure_trainer()

# Gradient tracking was disabled globally earlier in the notebook, so it has to
# be re-enabled for the duration of training.
with torch.autograd.enable_grad():
    model.train()
    trainer.fit(model)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
The following callbacks returned in `LightningModule.configure_callbacks` will override existing callbacks passed to Trainer: ModelCheckpoint
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type      | Params
----------------------------------------
0 | encoder   | Encoder   | 1.5 M 
1 | rectifier | Rectifier | 823 K 
----------------------------------------
2.3 M     Trainable params
0         Non-trainable params
2.3 M     Total params
9.338     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [None]:
# best_model = model.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
# model = best_model

In [None]:
# checkpoint = trainer.checkpoint_callback.best_model_path
# NOTE: this path points at one specific logged run (`version_5`); adjust it to the run you want to load.
checkpoint = "lightning_logs/version_5/checkpoints/last.ckpt"

# `load_from_checkpoint` is a classmethod that builds a NEW module, so call it
# on the class rather than on the existing instance (calling it on an instance
# is rejected by newer Lightning releases).
model = PointCloudsModule.load_from_checkpoint(checkpoint)

In [None]:
# Switch the model to evaluation mode before the sampling cells that follow.
model.eval()

In [None]:
# Display the hyperparameters stored on the model.
model.hparams

In [None]:
def plot_samples(samples):
    """
    Draw every point cloud in `samples` as a 3D scatter plot on a grid of subplots.

    The grid is derived from the number of samples: `int(sqrt(n))` columns and
    as many rows as needed to fit all of them. The figure is created with
    `plt.figure`, so it is the current figure afterwards and callers can
    title/save it via `plt.gcf()` / `plt.savefig`.

    Parameters
    ----------
    samples:
        Batch of point clouds. Must expose `.shape[0]` (number of clouds), and
        each element obtained by iterating over it is indexed as
        `points[:, 0]`, `points[:, 1]`, `points[:, 2]` for the x/y/z coordinates.

    Raises
    ------
    ValueError
        If `samples` contains no point clouds (the grid size would be zero).
    """
    n_samples = samples.shape[0]
    if n_samples == 0:
        # without this guard the row computation below divides by zero
        raise ValueError("plot_samples requires at least one sample, got an empty batch")

    cols = int(np.sqrt(n_samples))
    rows = int(np.ceil(n_samples / cols))

    # one inch per subplot
    fig = plt.figure(figsize=(cols, rows))

    for i, points in enumerate(samples):
        ax = fig.add_subplot(rows, cols, i + 1, projection="3d")
        ax.scatter(points[:, 0], points[:, 1], points[:, 2], s=1, color="black", alpha=0.5, lw=0)
        ax.set_axis_off()

    # set the spacing between subplots (on this figure explicitly, not via pyplot's global state)
    fig.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.9, wspace=0.0, hspace=0.0)

In [None]:
# Draw nine random shapes, store the raw tensor, then plot and save the grid.
samples = model.sample(n_shapes=9, n_points=2048, steps=100)
samples_file = samples_path / "random_samples.pt"
torch.save(samples, samples_file)

plot_samples(samples)
fig = plt.gcf()
fig.suptitle("Random Samples")
figure_file = plots_path / "random_samples.png"
plt.savefig(figure_file)

In [None]:
# Generate a large batch of random shapes and store it on disk (nothing is plotted in this cell).
samples = model.sample(n_shapes=1024, n_points=2048, steps=100)
torch.save(samples, samples_path / "many_random_samples.pt")

In [None]:
# %matplotlib widget
sample = model.sample(n_shapes=1, n_points=4096, steps=100).squeeze()
torch.save(sample, samples_path / "sample.pt")

fig = plt.figure()
ax = fig.add_subplot(projection="3d")

ax.scatter(sample[:, 0], sample[:, 1], sample[:, 2], s=1, color="black", alpha=0.5, lw=0)
plt.show()

In [None]:
# create a high-fidelity render of the above
s = bp.DefaultScene()
s.scatter(sample, alpha=0.5)
img = s.render(plots_path / "renders" / "sample.png", resolution=(1200, 1200), samples=128)
s.save(samples_path / "blendfiles" / "sample.blend")
img

In [None]:
# sample conditions with constant noise
samples = model.sample_shapes(n_shapes=9, n_points=2048, steps=100)
torch.save(samples, samples_path / "shapes.pt")

plot_samples(samples)
plt.gcf().suptitle("Sampling From Constant Noise With Varying Condition")
plt.savefig(plots_path / "shapes.png")

In [None]:
# sample noise with constant condition
samples = model.sample_variations(n_shapes=9, n_points=2048, steps=100)
torch.save(samples, samples_path / "variations.pt")

plot_samples(samples)
plt.gcf().suptitle("Sampling from Constant Condition with Varying Noise")
plt.savefig(plots_path / "variations.png")

In [None]:
# Grid of samples: `rows` noise draws combined with `cols` condition draws.
# NOTE: `plot_samples` derives its own grid from the number of samples
# (int(sqrt(n)) columns); for rows * cols = 48 that is 6 columns x 8 rows,
# which matches the layout intended here. Changing `rows`/`cols` can break
# that correspondence.
rows = 8
cols = 6
points = 2048
steps = 100

# Use dedicated names for the sampled tensors so the integer hyperparameter
# `conditions` defined with the model hyperparameters, and the point count
# `points` above, are not overwritten by tensors.
grid_conditions = model.encoder.distribution.sample((cols,)).to(model.device)
noise = model.rectifier.distribution.sample((rows, points)).to(model.device)

# Expand both to rows * cols entries along dim 0: each noise draw is repeated
# `cols` times consecutively.
# NOTE(review): `utils.repeat_dim` presumably tiles the condition block `rows` times - confirm.
grid_conditions = utils.repeat_dim(grid_conditions, rows, dim=0)
noise = noise.repeat_interleave(cols, dim=0)

grid_samples, _time = model.rectifier.inverse(noise, condition=grid_conditions, steps=steps)

plot_samples(grid_samples)
fig = plt.gcf()
fig.suptitle("Sampling with Varying Noise or Condition")
fig.supxlabel("Condition")
fig.supylabel("Noise")
plt.savefig(plots_path / "shape_variations.png")

In [None]:
# reconstruct random train samples
random_indices = torch.randperm(len(train_data))[:9]
random_samples = torch.stack([train_data[i] for i in random_indices])

reconstructions = model.reconstruct(random_samples, steps=100)
torch.save(reconstructions, samples_path / "reconstructions.pt")

plot_samples(reconstructions)
plt.gcf().suptitle("Reconstructions")
plt.savefig(plots_path / "reconstructions.png")

In [None]:
# Plot the original train samples that were fed to the reconstruction, for comparison.
plot_samples(random_samples)
fig = plt.gcf()
fig.suptitle("Train Data")
figure_file = plots_path / "train.png"
plt.savefig(figure_file)