In [None]:
import matplotlib.pyplot as plt
import pytorch_lightning as lightning
import torch
import torch.nn as nn
from torch.utils.data import random_split

from point_clouds.models import PointCloudsModel, PointCloudsModelHParams, EncoderHParams, RectifierHParams
from point_clouds.datasets import FurnitureDataset
from point_clouds.datasets import FurnitureDataset

In [None]:
# Global seed for the notebook. Per the Lightning docs, seed_everything seeds the
# python, numpy and torch RNGs; workers=True additionally makes DataLoader worker
# processes derive reproducible seeds from it.
SEED = 42
lightning.seed_everything(SEED, workers=True)

In [None]:
# Model hyperparameters are built once, further down, as a PointCloudsModelHParams
# instance (after the encoder/rectifier configs it depends on). No plain-dict draft
# is kept here, so there is no second, conflicting copy of values such as `conditions`.

In [None]:
# Encoder configuration: 3 inputs (presumably xyz coordinates — confirm), hidden
# widths 32 -> 64 -> 128, a 128-wide output and SELU activations.
# Constructed directly so `encoder_hparams` is never bound to a bare dict.
encoder_hparams = EncoderHParams(
    inputs=3,
    outputs=128,
    layer_widths=[32, 64, 128],
    activation="selu",
)

In [None]:
# Rectifier configuration, sized from the encoder config so the two stay in sync.
# Input width = encoder inputs + 1 + encoder outputs.
# NOTE(review): the extra 1 is presumably a scalar time/step input — confirm.
# The output width equals the encoder's input width.
rectifier_hparams = RectifierHParams(
    inputs=encoder_hparams.inputs + 1 + encoder_hparams.outputs,
    outputs=encoder_hparams.inputs,
    layer_widths=[128, 128, 128],
    activation="selu",
    integrator="euler",
)

In [None]:
# Top-level model hyperparameters. Widths shared with the encoder are derived from
# `encoder_hparams` (as the rectifier config already is) instead of being repeated
# as literals, so changing the encoder cannot silently desynchronise them.
hparams = PointCloudsModelHParams(
    inputs=encoder_hparams.inputs,  # currently 3
    # NOTE(review): the dataset cells use sub_samples=2048 — keep these in sync.
    points=2048,
    # NOTE(review): assumes the condition vector is the encoder output (currently 128) — confirm.
    conditions=encoder_hparams.outputs,

    encoder_hparams=encoder_hparams,
    rectifier_hparams=rectifier_hparams,

    # max_epochs=1 / batch_size=4 look like smoke-test settings; raise for a real run.
    max_epochs=1,
    batch_size=4,
)

In [None]:
# Training split of FurnitureDataset, downloaded into data/ if missing.
full_train_dataset = FurnitureDataset(
    root="data/",
    shapes="all",
    split="train",
    sub_samples=2048,
    samples=32768,
    download=True,
)

# 80/20 train/validation split. A dedicated, seeded generator makes the split
# reproducible even when this cell is re-run on its own; the global RNG state
# would otherwise depend on which cells ran before.
# Fractional `lengths` require torch >= 1.13.
train_dataset, val_dataset = random_split(
    full_train_dataset,
    lengths=[0.8, 0.2],
    generator=torch.Generator().manual_seed(42),
)

In [None]:
# Held-out test split, built with the same sampling settings as the training data.
# download=False: presumably relies on the train cell (download=True) having already
# populated data/ — confirm both splits come from the same download.
test_dataset = FurnitureDataset(
    split="test",
    root="data/",
    shapes="all",
    sub_samples=2048,
    samples=32768,
    download=False,
)

In [None]:
# Bundle the hyperparameters with the train / validation / test datasets.
model = PointCloudsModel(
    hparams,
    train_data=train_dataset,
    val_data=val_dataset,
    test_data=test_dataset,
)

In [None]:
# Run training. fit() takes no arguments here, so it presumably uses the datasets
# and hparams (max_epochs, batch_size) passed to PointCloudsModel above — confirm.
model.fit()

In [None]:
# Draw point clouds from the trained model and show each one as a 3-D scatter plot.
N_SAMPLES = 9
N_COLS = 3

# Sampling needs no gradients; no_grad avoids keeping an autograd graph alive in `samples`.
with torch.no_grad():
    samples = model.sample((N_SAMPLES,))

# NOTE(review): assumes `samples` is (N_SAMPLES, points, 3) with xyz on the last axis — TODO confirm.
clouds = torch.as_tensor(samples).detach().cpu().numpy()

n_rows = -(-N_SAMPLES // N_COLS)  # ceiling division
fig, axes = plt.subplots(
    n_rows,
    N_COLS,
    figsize=(4 * N_COLS, 4 * n_rows),
    subplot_kw={"projection": "3d"},
    squeeze=False,
)
for idx, (ax, cloud) in enumerate(zip(axes.flat, clouds)):
    ax.scatter(cloud[:, 0], cloud[:, 1], cloud[:, 2], s=1)
    ax.set_title(f"Sample {idx}")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")
fig.suptitle("Point clouds sampled from PointCloudsModel")
fig.tight_layout()
plt.show()