## Imports, configuration and test data

In [1]:
from datasets.topological import DataModule, DataModuleConfig
import torch
import matplotlib.pyplot as plt
from omegaconf import OmegaConf

from models.vae import VanillaVAE
from models.vae import BaseModel as BaseVAE
from models.encoder import BaseModel as EctEncoder
from layers.ect import EctLayer, EctConfig

from metrics.metrics import get_mse_metrics
from metrics.accuracies import compute_mse_accuracies
from metrics.loss import compute_mse_loss_fn
from layers.directions import generate_directions


# Use the first CUDA device when one is available; otherwise fall back to CPU
# so the notebook still runs (slowly) on machines without a GPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# Layer/model hyperparameters for the topological ECT encoder.
config = OmegaConf.load("./configs/config_encoder_topological.yaml")

# NOTE(review): the data module is built from DataModuleConfig() defaults, not
# from `config` above — confirm the defaults match the trained model's data.
dm = DataModule(DataModuleConfig())

print(len(dm.test_ds))

4000


## Build the ECT layer and load the trained ECT encoder

In [2]:


# ECT layer: `ect_size` directions (from generate_directions) and `ect_size`
# bump steps, with normalization enabled, placed on DEVICE.
layer = EctLayer(
    EctConfig(
        num_thetas=config.layer.ect_size,
        bump_steps=config.layer.ect_size,
        normalized=True,
        device=DEVICE,
    ),
    v=generate_directions(config.layer.ect_size, config.layer.dim, DEVICE),
)

# Load the trained ECT encoder; the keyword arguments below are passed through
# to load_from_checkpoint together with the ECT layer built above.
ENCODER_CKPT = "./trained_models/ectencoder_topological.ckpt"
ect_encoder_litmodel = EctEncoder.load_from_checkpoint(
    ENCODER_CKPT,
    map_location=DEVICE,  # load tensors straight onto DEVICE (also works on CPU-only hosts)
    layer=layer,
    ect_size=config.layer.ect_size,
    hidden_size=config.model.hidden_size,
    num_pts=config.model.num_pts,
    num_dims=config.model.num_dims,
    learning_rate=config.model.learning_rate,
).to(DEVICE)

# This notebook only runs inference: switch to eval mode so layers such as
# dropout/batch-norm (if the model has any) behave deterministically.
ect_encoder_litmodel.eval();


## Compute ECTs of selected test samples and reconstruct the point clouds with the ECT encoder

In [3]:
# Sanity check: size of the test split (displayed as the cell output).
n_test_samples = len(dm.test_ds)
n_test_samples

4000

In [4]:
import numpy as np
from torch_geometric.data import Batch
# Take the first SAMPLES_PER_BLOCK test indices starting at each offset in
# BLOCK_STARTS: 4 blocks x 16 = 64 samples, one per cell of the 8x8 plot grid.
SAMPLES_PER_BLOCK = 16
BLOCK_STARTS = (0, 1000, 2000, 3000)

idxs = np.concatenate(
    [np.arange(start, start + SAMPLES_PER_BLOCK, 1) for start in BLOCK_STARTS]
)

print(idxs)

[   0    1    2    3    4    5    6    7    8    9   10   11   12   13
   14   15 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011
 1012 1013 1014 1015 2000 2001 2002 2003 2004 2005 2006 2007 2008 2009
 2010 2011 2012 2013 2014 2015 3000 3001 3002 3003 3004 3005 3006 3007
 3008 3009 3010 3011 3012 3013 3014 3015]


In [5]:

# Collate the selected test samples into a single PyG batch (one graph per index).
features = Batch.from_data_list([dm.test_ds[el] for el in idxs])

print(features)
features = features.to(DEVICE)

# Inference only: compute the ECTs and decode them without building an
# autograd graph (the ECT itself is not needed for gradients here).
with torch.no_grad():
    # unsqueeze(1) inserts a size-1 axis at position 1 — presumably the channel
    # axis the encoder expects; confirm against the model definition.
    ect = layer(features, features.batch).unsqueeze(1)
    # Call the module rather than .forward() so nn.Module hooks are honoured.
    decoded = ect_encoder_litmodel.model(ect)

import pyvista as pv

# Plotter and ground-truth coordinates used only by the commented-out cell
# that plots the original point clouds.
pl = pv.Plotter(shape=(8, 8), window_size=[1600, 1600], border=False, polygon_smoothing=True)

# One (points_per_sample, 3) array per sample; derives the point count from the
# number of graphs instead of hard-coding 1024. Assumes every sample has the
# same number of points.
batch = features.x.cpu().detach().view(features.num_graphs, -1, 3).numpy()



DataBatch(x=[65536, 3], y=[64], batch=[65536], ptr=[65])


In [6]:
# for row in range(8):
#     for col in range(8):
#         points = batch[8*row + col].reshape(-1, 3)
#         pl.subplot(row, col)
#         actor = pl.add_points(
#             points,
#             style="points",
#             emissive=False,
#             show_scalar_bar=False,
#             render_points_as_spheres=True,
#             scalars=points[:, 2],
#             point_size=5,
#             ambient=0.2, 
#             diffuse=0.8, 
#             specular=0.8,
#             specular_power=40, 
#             smooth_shading=True
#         )


# pl.background_color = "w"
# pl.link_views()
# pl.camera_position = "yz"
# pos = pl.camera.position
# pl.camera.position = (pos[0],pos[1],pos[2]+3)
# pl.camera.azimuth = -45
# pl.camera.elevation = 10

# # create a top down light
# light = pv.Light(position=(0, 0, 3), positional=True,
#                 cone_angle=50, exponent=20, intensity=.2)
# pl.add_light(light)
# pl.camera.zoom(1.3)
# # pl.screenshot("./figures/img/topological/reconstructed_pointcloud.png",transparent_background=True,scale=2)
# pl.show()
# # path = pl.generate_orbital_path(n_points=64, shift=2, factor=3.0)
# # pl.open_gif("./figures/img/topological/reconstructed_pointcloud_full.gif")
# # pl.orbit_on_path(path, write_frames=True)
# # pl.close()


In [7]:
import pyvista as pv
from pathlib import Path

# Render the reconstructed point clouds on a GRID_ROWS x GRID_COLS grid off
# screen and save a PNG and an orbiting GIF under FIGURE_DIR.
GRID_ROWS, GRID_COLS = 8, 8
FIGURE_DIR = Path("./figures/img/topological")
# screenshot() / open_gif() write into this directory, so make sure it exists.
FIGURE_DIR.mkdir(parents=True, exist_ok=True)


def add_point_cloud_grid(plotter, clouds, n_rows, n_cols):
    """Draw ``clouds[n_cols * row + col]`` into subplot ``(row, col)`` of ``plotter``.

    Each cloud is reshaped to ``(-1, 3)`` and coloured by its third coordinate.

    Args:
        plotter: a ``pv.Plotter`` created with ``shape=(n_rows, n_cols)``.
        clouds: indexable collection of point arrays, one per subplot.
        n_rows, n_cols: grid dimensions.

    Raises:
        ValueError: if ``clouds`` has fewer entries than there are subplots.
    """
    n_cells = n_rows * n_cols
    if len(clouds) < n_cells:
        raise ValueError(f"need {n_cells} point clouds for the grid, got {len(clouds)}")
    for row in range(n_rows):
        for col in range(n_cols):
            points = clouds[n_cols * row + col].reshape(-1, 3)
            plotter.subplot(row, col)
            plotter.add_points(
                points,
                style="points",
                emissive=False,
                show_scalar_bar=False,
                render_points_as_spheres=True,
                scalars=points[:, 2],
                point_size=5,
                ambient=0.2,
                diffuse=0.8,
                specular=0.8,
                specular_power=40,
                smooth_shading=True,
            )


pl = pv.Plotter(
    shape=(GRID_ROWS, GRID_COLS),
    window_size=[1600, 1600],
    border=False,
    polygon_smoothing=True,
    off_screen=True,
)

# Decoded point coordinates moved to NumPy; each entry is reshaped to (-1, 3) when plotted.
batch = decoded.cpu().detach().numpy()
add_point_cloud_grid(pl, batch, GRID_ROWS, GRID_COLS)

pl.background_color = "w"
pl.link_views()  # all subplots share one camera, so the settings below apply to every panel
pl.camera_position = "yz"
pos = pl.camera.position
pl.camera.position = (pos[0], pos[1], pos[2] + 3)
pl.camera.azimuth = -45
pl.camera.elevation = 10

# Create a top-down light.
light = pv.Light(position=(0, 0, 3), positional=True,
                 cone_angle=50, exponent=20, intensity=.2)
pl.add_light(light)
pl.camera.zoom(1.3)
pl.screenshot(str(FIGURE_DIR / "reconstructed_pointcloud.png"), transparent_background=True, scale=2)
# NOTE(review): show() runs before the GIF is recorded. This works with the
# notebook backend, but where show() auto-closes the plotter the orbit below
# would fail — consider moving show() after orbit_on_path if run as a script.
pl.show()
path = pl.generate_orbital_path(n_points=64, shift=2, factor=3.0)
pl.open_gif(str(FIGURE_DIR / "reconstructed_pointcloud.gif"))
pl.orbit_on_path(path, write_frames=True)
pl.close()


Widget(value='<iframe src="http://localhost:57647/index.html?ui=P_0x22af6ab8880_1&reconnect=auto" class="pyvis…