## Imports and configuration

In [2]:
import math
import os

import torch
import matplotlib.pyplot as plt
from omegaconf import OmegaConf


from models.encoder import BaseModel as EctEncoder
from datasets.modelnet import DataModule, DataModuleConfig
from layers.ect import EctLayer, EctConfig
from layers.directions import generate_directions


# Use the first GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# Experiment configuration: provides the `layer.*` and `model.*` settings used to
# build the ECT layer and the encoder below.
CONFIG_PATH = "./configs/config_encoder_modelnet.yaml"
config = OmegaConf.load(CONFIG_PATH)

## Load Models and Data

In [3]:
# ECT layer with `ect_size` directions and `ect_size` bump steps
# (presumably producing an ect_size x ect_size ECT per sample — confirm in EctLayer).
layer = EctLayer(
    EctConfig(
        num_thetas=config.layer.ect_size,
        bump_steps=config.layer.ect_size,
        normalized=True,
        device=DEVICE,
    ),
    v=generate_directions(config.layer.ect_size, config.layer.dim, DEVICE),
)

# ModelNet40 data module; only its validation split is used in this notebook.
dm = DataModule(DataModuleConfig(name="40"))

# Load the trained encoder and prepare it for inference.
# `map_location` lets a checkpoint saved on GPU load when DEVICE is the CPU.
# `.eval()` disables training-mode behaviour (e.g. dropout / batch-norm updates,
# if the model has such layers) so reconstructions are deterministic.
ect_encoder_litmodel = EctEncoder.load_from_checkpoint(
    "./trained_models/ectencoder_modelnet.ckpt",
    map_location=DEVICE,
    layer=layer,
    ect_size=config.layer.ect_size,
    hidden_size=config.model.hidden_size,
    num_pts=config.model.num_pts,
    num_dims=config.model.num_dims,
    learning_rate=config.model.learning_rate,
).to(DEVICE).eval()


Processing...
Done!


In [5]:
import pyvista as pv

# Ground truth: render the first 8 point clouds of one validation batch side by
# side. The reconstruction cells below reuse this same `features` batch.
features = next(iter(dm.val_dataloader()))

# Rendering options shared by every panel; each cloud is also coloured by its
# z coordinate (passed separately as `scalars`).
point_style = dict(
    style="points",
    emissive=False,
    show_scalar_bar=False,
    render_points_as_spheres=True,
    point_size=5,
    ambient=0.2,
    diffuse=0.8,
    specular=0.8,
    specular_power=40,
    smooth_shading=True,
)

pl = pv.Plotter(
    shape=(1, 8),
    window_size=[1600, 200],
    border=False,
    polygon_smoothing=True,
    off_screen=True,
)

for idx in range(8):
    # `.cpu()` is a no-op for CPU tensors and keeps this line working even if
    # the batch has already been moved to DEVICE.
    points = features[idx].x.reshape(-1, 3).cpu().numpy()
    pl.subplot(0, idx)
    pl.add_points(points, scalars=points[:, 2], **point_style)

pl.background_color = "w"
# Link the panels' cameras so the view configured below applies to all of them.
pl.link_views()
pl.camera_position = "yz"
cam_x, cam_y, cam_z = pl.camera.position
pl.camera.position = (cam_x, cam_y, cam_z + 3)  # move the camera 3 units along +z
pl.camera.azimuth = -45
pl.camera.elevation = 10

# Top-down positional (spot) light.
pl.add_light(
    pv.Light(position=(0, 0, 3), positional=True, cone_angle=50, exponent=20, intensity=0.2)
)
pl.camera.zoom(1.3)
pl.show()


Widget(value='<iframe src="http://localhost:54646/index.html?ui=P_0x1deb90c5ea0_1&reconnect=auto" class="pyvis…

## Compute the ECT of each point cloud and reconstruct the point cloud from it with the ECT encoder

In [6]:


# Compute the ECT of every point cloud in the batch and pass it through the
# encoder, which returns one reconstructed point cloud per ECT.
features = features.to(DEVICE)
with torch.no_grad():
    # Insert a size-1 axis at dim 1 (presumably the channel axis the encoder expects).
    ect = layer(features, features.batch).unsqueeze(1)
    # Call the module instead of `.forward` so registered hooks also run.
    recon_batch = ect_encoder_litmodel.model(ect)

# Show at most an 8x8 grid of reconstructions; use fewer rows if the batch is
# smaller than 64 samples instead of failing with an IndexError.
GRID_COLS = 8
n_clouds = min(GRID_COLS * GRID_COLS, len(recon_batch))
grid_rows = math.ceil(n_clouds / GRID_COLS)

# Rendering options shared by every panel; each cloud is also coloured by z.
point_style = dict(
    style="points",
    emissive=False,
    show_scalar_bar=False,
    render_points_as_spheres=True,
    point_size=5,
    ambient=0.2,
    diffuse=0.8,
    specular=0.8,
    specular_power=40,
    smooth_shading=True,
)

pl = pv.Plotter(
    shape=(grid_rows, GRID_COLS),
    window_size=[1600, 200 * grid_rows],  # 200 px per row, as in the 1600x1600 8x8 layout
    border=False,
    polygon_smoothing=True,
    off_screen=True,
)

for sample in range(n_clouds):
    row, col = divmod(sample, GRID_COLS)
    points = recon_batch[sample].reshape(-1, 3).cpu().numpy()
    pl.subplot(row, col)
    pl.add_points(points, scalars=points[:, 2], **point_style)

pl.background_color = "w"
# Link the panels' cameras so the view configured below applies to all of them.
pl.link_views()
pl.camera_position = "yz"
cam_x, cam_y, cam_z = pl.camera.position
pl.camera.position = (cam_x, cam_y, cam_z + 3)  # move the camera 3 units along +z
pl.camera.azimuth = -45
pl.camera.elevation = 10

# Top-down positional (spot) light.
pl.add_light(
    pv.Light(position=(0, 0, 3), positional=True, cone_angle=50, exponent=20, intensity=0.2)
)
pl.camera.zoom(1.3)
pl.show()

# Orbit animation of the grid; create the output folder first in case it does
# not exist yet.
GIF_PATH = "./figures/reconstructed_modelnet/orbit_cloud.gif"
os.makedirs(os.path.dirname(GIF_PATH), exist_ok=True)
orbit_path = pl.generate_orbital_path(n_points=64, shift=2, factor=3.0)
pl.open_gif(GIF_PATH)
pl.orbit_on_path(orbit_path, write_frames=True)
pl.close()


Widget(value='<iframe src="http://localhost:54646/index.html?ui=P_0x1dec3fb9210_2&reconnect=auto" class="pyvis…

In [15]:


# Side-by-side comparison for one group of 8 samples: reconstructions in the top
# row, the corresponding ground-truth point clouds in the bottom row.
features = features.to(DEVICE)
with torch.no_grad():
    # Insert a size-1 axis at dim 1 (presumably the channel axis the encoder expects).
    ect = layer(features, features.batch).unsqueeze(1)
    # Call the module instead of `.forward` so registered hooks also run.
    recon_batch = ect_encoder_litmodel.model(ect)

n_cols = 8
row = 6  # shows samples row*8 .. row*8 + 7 of the batch
first = row * n_cols
assert first + n_cols <= len(recon_batch), (
    f"row {row} needs {first + n_cols} samples, batch has {len(recon_batch)}"
)

SCREENSHOT_PATH = "./figures/img/modelnet/reconstructed_pointcloud.png"
GIF_PATH = "./figures/img/modelnet/orbit_cloud.gif"

# Rendering options shared by every panel; each cloud is also coloured by z.
point_style = dict(
    style="points",
    emissive=False,
    show_scalar_bar=False,
    render_points_as_spheres=True,
    point_size=5,
    ambient=0.2,
    diffuse=0.8,
    specular=0.8,
    specular_power=40,
    smooth_shading=True,
)

pl = pv.Plotter(
    shape=(2, n_cols),
    window_size=[1600, 400],
    border=False,
    polygon_smoothing=True,
    off_screen=True,
)

for col in range(n_cols):
    sample = first + col
    recon_points = recon_batch[sample].reshape(-1, 3).cpu().numpy()
    true_points = features[sample].x.reshape(-1, 3).cpu().numpy()
    # Plotter row 0: reconstruction, plotter row 1: ground truth.
    for plot_row, points in ((0, recon_points), (1, true_points)):
        pl.subplot(plot_row, col)
        pl.add_points(points, scalars=points[:, 2], **point_style)

pl.background_color = "w"
# Link the panels' cameras so the view configured below applies to all of them.
pl.link_views()
pl.camera_position = "yz"
cam_x, cam_y, cam_z = pl.camera.position
pl.camera.position = (cam_x, cam_y, cam_z + 3)  # move the camera 3 units along +z
pl.camera.azimuth = -45
pl.camera.elevation = 10

# Top-down positional (spot) light.
pl.add_light(
    pv.Light(position=(0, 0, 3), positional=True, cone_angle=50, exponent=20, intensity=0.2)
)
pl.camera.zoom(1.3)

# Create the output folder first in case it does not exist yet.
os.makedirs(os.path.dirname(SCREENSHOT_PATH), exist_ok=True)
pl.screenshot(SCREENSHOT_PATH, transparent_background=True, scale=2)
pl.show()

orbit_path = pl.generate_orbital_path(n_points=64, shift=2, factor=3.0)
pl.open_gif(GIF_PATH)
pl.orbit_on_path(orbit_path, write_frames=True)
pl.close()

Widget(value='<iframe src="http://localhost:54646/index.html?ui=P_0x1e132307580_11&reconnect=auto" class="pyvi…