## Prepare imports

In [38]:
import torch
import numpy as np 
from datasets.topological import DataModule, DataModuleConfig
import matplotlib.pyplot as plt
import pyvista as pv
from torch_geometric.data import Batch
import pyvista as pv

from models.encoder import BaseModel as Encoder

from load_configs import load_config

import matplotlib.pyplot as plt

# Device used for all models/tensors in this notebook; fall back to CPU so the
# notebook still runs on machines without a CUDA GPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

from types import SimpleNamespace
from datasets import load_datamodule
import yaml
import json

from models.encoder import BaseModel as Encoder


encoder_config = load_config("./configs/config_encoder_topological.yaml")
dm = load_datamodule(encoder_config.data)


encoder = Encoder.load_from_checkpoint(f"./trained_models/{encoder_config.modelconfig.save_name}").to(DEVICE)




C:\Users\ernst\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\lightning\pytorch\utilities\migration\utils.py:56: The loaded checkpoint was produced with Lightning v2.3.3, which is newer than your current Lightning version: v2.2.3


In [39]:
# Inspect the size of the ECT layer's `v` tensor (displayed as torch.Size([3, 96]) below).
# NOTE(review): presumably these are the direction vectors of the ECT — TODO confirm.
encoder.layer.v.size()

torch.Size([3, 96])

## Load Test Data (one sample per manifold class)

In [40]:
# Build a small test batch containing SAMPLES_PER_CLASS samples of every manifold class.
SAMPLES_PER_CLASS = 1

test_ds = dm.test_ds
mfld_classes = test_ds.y.unique()

test_data_list = []
for class_label in mfld_classes:
    # Filter the dataset once per class rather than once per sample.
    class_ds = test_ds[test_ds.y == class_label]
    # Do not index past the end if a class has fewer samples than requested.
    n_take = min(SAMPLES_PER_CLASS, len(class_ds))
    test_data_list.extend(class_ds[i] for i in range(n_take))

test_batch = Batch.from_data_list(test_data_list).to(DEVICE)

# One point cloud of shape (num_points, 3) per graph; infer num_points from the
# batch instead of hard-coding it (it was 1024).
points_batch = test_batch.x.detach().cpu().view(test_batch.num_graphs, -1, 3).numpy()


In [41]:

# Compute the ECT of every test sample, reconstruct the point clouds with the
# encoder, and render (up to) N_RECON_PLOTS reconstructions side by side.
N_RECON_PLOTS = 4

with torch.no_grad():
    # Computing the ECT inside no_grad avoids building an autograd graph we never use.
    ect = encoder.layer(test_batch, test_batch.batch).unsqueeze(1)
    recon_pts = encoder(ect)

# Never index past the number of reconstructions actually available
# (the test batch holds one sample per class, which may be fewer than 4).
n_plots = min(N_RECON_PLOTS, recon_pts.shape[0])
recon_np = recon_pts.detach().cpu().numpy()

pl = pv.Plotter(shape=(1, n_plots), window_size=[400 * n_plots, 400], border=False, polygon_smoothing=True)

for idx in range(n_plots):
    points = recon_np[idx].reshape(-1, 3)
    pl.subplot(0, idx)
    pl.add_points(
        points,
        style="points",
        emissive=False,
        show_scalar_bar=False,
        render_points_as_spheres=True,
        scalars=points[:, 2],  # colour by height
        point_size=2,
        ambient=0.2,
        diffuse=0.8,
        specular=0.8,
        specular_power=40,
        smooth_shading=True,
    )

pl.background_color = "w"
pl.link_views()
pl.camera_position = "yz"
pos = pl.camera.position
pl.camera.position = (pos[0], pos[1], pos[2] + 3)
pl.camera.azimuth = -45
pl.camera.elevation = 10

# create a top down light
light = pv.Light(position=(0, 0, 3), positional=True,
                 cone_angle=50, exponent=20, intensity=.2)
pl.add_light(light)
pl.camera.zoom(1.3)
pl.screenshot("./figures/img/topological/reconstructed_pointcloud.png", transparent_background=True, scale=2)
pl.show()



Widget(value='<iframe src="http://localhost:50074/index.html?ui=P_0x1b4375485e0_16&reconnect=auto" class="pyvi…

## Interpolate between the ECTs

In [42]:
# Linearly interpolate between the ECTs of test samples START and END and
# decode every intermediate ECT into a point cloud (one panel per step).
START = 1
END = 3
STEPS = 10

with torch.no_grad():
    ect = encoder.layer(test_batch, test_batch.batch).unsqueeze(1)
    # Weights go 0 -> 1 so the first panel is ect[START] and the last is ect[END]
    # (the previous formula ran from END to START). Build them on ect's device
    # instead of hard-coding .cuda().
    t = torch.linspace(0, 1, STEPS, device=ect.device, dtype=ect.dtype).view(-1, 1, 1)
    # lerp broadcasts the (1, ...) endpoints against the (STEPS, 1, 1) weights,
    # so no explicit .repeat is needed. Assumes ect is (B, 1, H, W) — TODO confirm.
    ect_interp = torch.lerp(ect[START], ect[END], t).to(torch.float32)
    ect_interp_pc = encoder(ect_interp.unsqueeze(1))

interp_np = ect_interp_pc.detach().cpu().numpy()

pl = pv.Plotter(shape=(1, STEPS), window_size=[1600, 400], border=False, polygon_smoothing=True)

for idx in range(STEPS):
    points = interp_np[idx].reshape(-1, 3)
    pl.subplot(0, idx)
    pl.add_points(
        points,
        style="points",
        emissive=False,
        show_scalar_bar=False,
        render_points_as_spheres=True,
        scalars=points[:, 2],  # colour by height
        point_size=6,
        ambient=0.2,
        diffuse=0.8,
        specular=0.8,
        specular_power=40,
        smooth_shading=True,
    )

pl.background_color = "w"
pl.link_views()
pl.camera_position = "yz"
pos = pl.camera.position
pl.camera.position = (pos[0], pos[1], pos[2] + 3)
pl.camera.azimuth = -45
pl.camera.elevation = 10

# create a top down light
light = pv.Light(position=(0, 0, 3), positional=True,
                 cone_angle=50, exponent=20, intensity=.2)
pl.add_light(light)
pl.camera.zoom(1.3)
pl.screenshot("./figures/img/topological/interpolate_mobius_torus_pointcloud.png", transparent_background=True, scale=2)
pl.show()


Widget(value='<iframe src="http://localhost:50074/index.html?ui=P_0x1b437433430_17&reconnect=auto" class="pyvi…

## Sample from the VAE and reconstruct point clouds

In [43]:
# Draw random ECTs from the VAE and show the first n_images of them as images.
# FIXME: `vae_litmodel` is never loaded in this notebook, so this cell raises
# NameError on a fresh kernel; load the trained VAE before running this section.
with torch.no_grad():
    samples = vae_litmodel.model.sample(64, DEVICE)

# Undo the VAE transform: x -> (x + 1) / 2.
samples = (samples + 1) / 2

n_images = 2

# squeeze=False keeps `axes` a 2-D array even when n_images == 1.
fig, axes = plt.subplots(
    nrows=1, ncols=n_images, sharex=True, sharey=True, figsize=(4, 4), squeeze=False
)
fig.subplots_adjust(wspace=0.05, hspace=0.05)

for sample, ax in zip(samples, axes.ravel()):
    ax.imshow(sample.detach().cpu().squeeze().numpy(), cmap="bone", vmin=-0.5, vmax=1.5)
    ax.axis("off")

fig.savefig("./figures/img/topological/generated_samples.svg", transparent=True)

NameError: name 'vae_litmodel' is not defined

In [12]:

# Decode the VAE samples into point clouds and render the first n_generated of them.
n_generated = 2

with torch.no_grad():
    # `ect_encoder_litmodel` is not defined anywhere in this notebook; use the
    # `encoder` loaded in the first cell (the same call used for reconstructions).
    # NOTE(review): assumes `samples` has the same layout as the ECTs passed to
    # `encoder(ect)` — TODO confirm.
    batch_decoded = encoder(samples)

batch_decoded = batch_decoded.detach().cpu().numpy()

pl = pv.Plotter(shape=(1, n_generated), window_size=[400, 200], border=False, polygon_smoothing=True)

for idx in range(n_generated):
    points = batch_decoded[idx].reshape(-1, 3)
    pl.subplot(0, idx)
    pl.add_points(
        points,
        style="points",
        emissive=False,
        show_scalar_bar=False,
        render_points_as_spheres=True,
        scalars=points[:, 2],  # colour by height
        point_size=5,
        ambient=0.2,
        diffuse=0.8,
        specular=0.8,
        specular_power=40,
        smooth_shading=True,
    )

pl.background_color = "w"
pl.link_views()
pl.camera_position = "yz"
pos = pl.camera.position
pl.camera.position = (pos[0], pos[1], pos[2] + 3)
pl.camera.azimuth = -45
pl.camera.elevation = 10

# create a top down light
light = pv.Light(position=(0, 0, 3), positional=True,
                 cone_angle=50, exponent=20, intensity=.2)
pl.add_light(light)
pl.camera.zoom(1.3)
pl.screenshot("./figures/img/topological/generated_samples.png", transparent_background=True, scale=2)
pl.show()

Widget(value='<iframe src="http://localhost:64571/index.html?ui=P_0x22242b8fc40_3&reconnect=auto" class="pyvis…