## Prepare imports

In [18]:
from datasets import get_datasets, synsetid_to_cate
from args import get_parser
from pprint import pprint
# from metrics.evaluation_metrics import EMD_CD
# from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD
# from metrics.evaluation_metrics import compute_all_metrics
from collections import defaultdict
from models.networks import PointFlow
import os
import torch
import numpy as np
import torch.nn as nn
import pyvista as pv



################################################################################
# Import model and additional stuff
################################################################################

from omegaconf import OmegaConf
from model_wrapper import TopologicalModelVAE
from load_models import load_encoder, load_vae


################################################################################
################################################################################


def get_test_loader(args):
    _, te_dataset = get_datasets(args)
    if args.resume_dataset_mean is not None and args.resume_dataset_std is not None:
        mean = np.load(args.resume_dataset_mean)
        std = np.load(args.resume_dataset_std)
        te_dataset.renormalize(mean, std)
    loader = torch.utils.data.DataLoader(
        dataset=te_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        drop_last=False,
    )
    return loader



In [19]:
p = get_parser()
args, unknown = p.parse_known_args()
args.cates = ["airplane"]

In [20]:
# loader = get_test_loader(args)

# for data in loader:
#     idx_b, tr_pc, te_pc = data["idx"], data["train_points"], data["test_points"]
#     break

# pl = pv.Plotter(shape=(8,8), window_size=[1600, 1600],border=False,polygon_smoothing=True)

# for row in range(8):
#     for col in range(4):
#         points = te_pc[row*col + col].reshape(-1, 3).detach().cpu().numpy()
#         pl.subplot(row, col)
#         actor = pl.add_points(
#             points,
#             style="points",
#             emissive=False,
#             show_scalar_bar=False,
#             render_points_as_spheres=True,
#             scalars=points[:, 2],
#             point_size=5,
#             ambient=0.2, 
#             diffuse=0.8, 
#             specular=0.8,
#             specular_power=40, 
#             smooth_shading=True
#         )


# pl.background_color = "w"
# pl.link_views()
# pl.camera_position = "yz"
# pos = pl.camera.position
# pl.camera.position = (pos[0],pos[1],pos[2]+3)
# pl.camera.azimuth = -45
# pl.camera.elevation = 10
# # create a top down light
# light = pv.Light(position=(0, 0, 3), positional=True,
#                 cone_angle=50, exponent=20, intensity=.2)
# pl.add_light(light)
# pl.camera.zoom(1.3)
# pl.show()

In [21]:
# Instead of PointFlow we load our own modelwrapper, that
# handles the function signatures of input and output.
ect_config = OmegaConf.load(
    f"./configs/config_encoder_shapenet_{args.cates[0]}.yaml"
)
vae_config = OmegaConf.load(f"./configs/config_vae_shapenet_{args.cates[0]}.yaml")

encoder_model = load_encoder(ect_config)
vae = load_vae(vae_config)
# vae = vae.eval()
model = TopologicalModelVAE(encoder_model, vae)


C:\Users\ernst\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\lightning\pytorch\utilities\migration\utils.py:56: The loaded checkpoint was produced with Lightning v2.3.3, which is newer than your current Lightning version: v2.2.3


  | Name                | Type             | Params
---------------------------------------------------------
0 | layer               | EctLayer         | 0     
1 | training_accuracy   | MeanSquaredError | 0     
2 | validation_accuracy | MeanSquaredError | 0     
3 | test_accuracy       | MeanSquaredError | 0     
4 | loss_fn             | MSELoss          | 0     
5 | model               | Sequential       | 39.9 M
---------------------------------------------------------
39.9 M    Trainable params
0         Non-trainable params
39.9 M    Total params
159.433   Total estimated model params size (MB)


In [22]:
print("=====GEN======")
loader = get_test_loader(args)
all_sample = []
all_ref = []
for data in loader:
    idx_b, tr_pc, te_pc = data["idx"], data["train_points"], data["test_points"]

    te_pc = te_pc.cuda() if args.gpu is None else te_pc.cuda(args.gpu)
    tr_pc = tr_pc.cuda() if args.gpu is None else tr_pc.cuda(args.gpu)
    B, N = te_pc.size(0), te_pc.size(1)
    out_pc = model.reconstruct(tr_pc, num_points=N)

    # denormalize
    m, s = data["mean"].float(), data["std"].float()
    m = m.cuda() if args.gpu is None else m.cuda(args.gpu)
    s = s.cuda() if args.gpu is None else s.cuda(args.gpu)
    out_pc = out_pc * s + m
    te_pc = te_pc * s + m


    ########################
    ## Insert
    ########################

    # For comparison, we scale the output for test point cloud to have
    # unit radius. The radius of the recon pc is scaled with the same
    # value to make sure relative distances are preserved.
    # @ErnstRoell
    # te_pc_means = te_pc.mean(axis=-2, keepdim=True)
    # te_pc = te_pc - te_pc_means
    # te_pc_norms = torch.norm(te_pc, dim=-1,keepdim=True).max(dim=-2,keepdim=True)[0]
    # te_pc = te_pc / te_pc_norms


    ########################
    ## End insert
    ########################


    all_sample.append(out_pc)
    all_ref.append(te_pc)

sample_pcs = torch.cat(all_sample, dim=0)
ref_pcs = torch.cat(all_ref, dim=0)



Total number of data:2832
Min number of points: (train)2048 (test)2048
Total number of data:405
Min number of points: (train)2048 (test)2048


In [23]:
jnt = torch.cat([sample_pcs,ref_pcs],axis=1)

jnt.shape

torch.Size([405, 4096, 3])

In [24]:
pl = pv.Plotter(shape=(3,10), window_size=[2000, 600],border=False,polygon_smoothing=True)

offset = 200
for col in range(10):
    points = ref_pcs[col+offset].reshape(-1, 3).detach().cpu().numpy()
    pl.subplot(0, col)
    actor = pl.add_points(
        points,
        style="points",
        emissive=False,
        show_scalar_bar=False,
        render_points_as_spheres=True,
        scalars=points[:, 2],
        point_size=2,
        ambient=0.2, 
        diffuse=0.8, 
        specular=0.8,
        specular_power=40, 
        smooth_shading=True
    )
    points = sample_pcs[col+offset].reshape(-1, 3).detach().cpu().numpy()
    pl.subplot(1, col)
    actor = pl.add_points(
        points,
        style="points",
        emissive=False,
        show_scalar_bar=False,
        render_points_as_spheres=True,
        scalars=points[:, 2],
        point_size=2,
        ambient=0.2, 
        diffuse=0.8, 
        specular=0.8,
        specular_power=40, 
        smooth_shading=True
    )
    points = jnt[col+offset].reshape(-1, 3).detach().cpu().numpy()
    pl.subplot(2, col)
    actor = pl.add_points(
        points,
        style="points",
        emissive=False,
        show_scalar_bar=False,
        render_points_as_spheres=True,
        scalars=points[:, 2],
        point_size=2,
        ambient=0.2, 
        diffuse=0.8, 
        specular=0.8,
        specular_power=40, 
        smooth_shading=True
    )


pl.background_color = "w"
pl.link_views()
pl.camera_position = "yz"
pos = pl.camera.position
print(pos)
# pl.camera.position = (pos[0],pos[1]+3,pos[2])
pl.camera.position = (0,20,0)
pl.camera.azimuth = 0
pl.camera.elevation = 0
# create a top down light
light = pv.Light(position=(0, 0, 0), positional=True,
                cone_angle=50, exponent=20, intensity=.2)
pl.add_light(light)
pl.camera.zoom(1.3)
# pl.screenshot(f"./img/reconstructed_vae_pointcloud{offset}.png",transparent_background=True,scale=2)
# pl.clear()
pl.show()

(1.91277996058047, 0.06046977639198303, 0.03836993873119354)


Widget(value='<iframe src="http://localhost:52812/index.html?ui=P_0x2115168f250_4&reconnect=auto" class="pyvis…

In [25]:
# pl = pv.Plotter(shape=(3,10), window_size=[1600, 600],border=False,polygon_smoothing=True)

# for offset in range(0,100,10):
#     for col in range(10):
#         points = ref_pcs[col+offset].reshape(-1, 3).detach().cpu().numpy()
#         pl.subplot(0, col)
#         actor = pl.add_points(
#             points,
#             style="points",
#             emissive=False,
#             show_scalar_bar=False,
#             render_points_as_spheres=True,
#             scalars=points[:, 2],
#             point_size=2,
#             ambient=0.2, 
#             diffuse=0.8, 
#             specular=0.8,
#             specular_power=40, 
#             smooth_shading=True
#         )
#         points = sample_pcs[col+offset].reshape(-1, 3).detach().cpu().numpy()
#         pl.subplot(1, col)
#         actor = pl.add_points(
#             points,
#             style="points",
#             emissive=False,
#             show_scalar_bar=False,
#             render_points_as_spheres=True,
#             scalars=points[:, 2],
#             point_size=2,
#             ambient=0.2, 
#             diffuse=0.8, 
#             specular=0.8,
#             specular_power=40, 
#             smooth_shading=True
#         )
#         points = jnt[col+offset].reshape(-1, 3).detach().cpu().numpy()
#         pl.subplot(2, col)
#         actor = pl.add_points(
#             points,
#             style="points",
#             emissive=False,
#             show_scalar_bar=False,
#             render_points_as_spheres=True,
#             scalars=points[:, 2],
#             point_size=2,
#             ambient=0.2, 
#             diffuse=0.8, 
#             specular=0.8,
#             specular_power=40, 
#             smooth_shading=True
#         )


#     pl.background_color = "w"
#     pl.link_views()
#     pl.camera_position = "yz"
#     pos = pl.camera.position
#     print(pos)
#     # pl.camera.position = (pos[0],pos[1]+3,pos[2])
#     pl.camera.position = (0,20,0)
#     pl.camera.azimuth = 0
#     pl.camera.elevation = 0
#     # create a top down light
#     light = pv.Light(position=(0, 0, 0), positional=True,
#                     cone_angle=50, exponent=20, intensity=.2)
#     pl.add_light(light)
#     pl.camera.zoom(1.3)
#     pl.screenshot(f"./img/reconstructed_vae_pointcloud{offset}.png",transparent_background=True,scale=2)
#     pl.clear()
#     # pl.show()

In [26]:
# pos = pl.camera.position
# print(pos)

# Reference PointClouds

In [27]:
# pl = pv.Plotter(shape=(8,8), window_size=[1600, 1600],border=False,polygon_smoothing=True)

# for row in range(8):
#     for col in range(8):
#         points = ref_pcs[row*col + col].reshape(-1, 3).detach().cpu().numpy()
#         pl.subplot(row, col)
#         actor = pl.add_points(
#             points,
#             style="points",
#             emissive=False,
#             show_scalar_bar=False,
#             render_points_as_spheres=True,
#             scalars=points[:, 2],
#             point_size=5,
#             ambient=0.2, 
#             diffuse=0.8, 
#             specular=0.8,
#             specular_power=40, 
#             smooth_shading=True
#         )


# pl.background_color = "w"
# pl.link_views()
# pl.camera_position = "yz"
# pos = pl.camera.position
# pl.camera.position = (pos[0],pos[1],pos[2]+3)
# pl.camera.azimuth = -45
# pl.camera.elevation = 10
# # create a top down light
# light = pv.Light(position=(0, 0, 3), positional=True,
#                 cone_angle=50, exponent=20, intensity=.2)
# pl.add_light(light)
# pl.camera.zoom(1.3)
# pl.show()

# Reconstructed Point Clouds

In [28]:
# pl = pv.Plotter(shape=(8,8), window_size=[1600, 1600],border=False,polygon_smoothing=True)

# for row in range(8):
#     for col in range(8):
#         points = ref_pcs[row*col + col].reshape(-1, 3).detach().cpu()
#         points = points + .01*torch.randn_like(points)
#         pl.subplot(row, col)
#         actor = pl.add_points(
#             points.numpy(),
#             style="points",
#             emissive=False,
#             show_scalar_bar=False,
#             render_points_as_spheres=True,
#             scalars=points[:, 2],
#             point_size=2,
#             ambient=0.2, 
#             diffuse=0.8, 
#             specular=0.8,
#             specular_power=40, 
#             smooth_shading=True
#         )


# pl.background_color = "w"
# pl.link_views()
# pl.camera_position = "yz"
# pos = pl.camera.position
# pl.camera.position = (pos[0],pos[1],pos[2]+3)
# pl.camera.azimuth = -45
# pl.camera.elevation = 10
# # create a top down light
# light = pv.Light(position=(0, 0, 3), positional=True,
#                 cone_angle=50, exponent=20, intensity=.2)
# pl.add_light(light)
# pl.camera.zoom(1.3)
# pl.show()

In [29]:
# samples_norm = sample_pcs.norm(dim=-1)
# ref_norm = ref_pcs.norm(dim=-1)
# print(ref_norm.shape)
# print(samples_norm.shape)


# print(ref_norm.max(axis=-1)[0])
# print(samples_norm.max(axis=-1)[0])

In [30]:
# import matplotlib.pyplot as plt
# ect = model.vae.model.sample(10, "cuda:0")
# for i in range(8):
#     plt.imshow(ect[i].squeeze().detach().cpu().numpy())
#     plt.savefig(f"ECT_{i}.png")

In [33]:

NROWS = 8


ect_samples, new_samples  = model.sample(NROWS*8,2048)

# new_samples_list = []


# with torch.no_grad():
#     for _ in range(8):
#         s = model.reconstruct(s)
#         new_samples_list.append(s)

# new_samples = torch.cat(new_samples_list)

pl = pv.Plotter(shape=(NROWS,8), window_size=[1600, NROWS*200],border=False,polygon_smoothing=True)




for row in range(NROWS):
    for col in range(8):
        points = new_samples[row*8 + col].reshape(-1, 3).detach().cpu().numpy()
        pl.subplot(row, col)
        actor = pl.add_points(
            points,
            style="points",
            emissive=False,
            show_scalar_bar=False,
            render_points_as_spheres=True,
            scalars=points[:, 2],
            point_size=2,
            ambient=0.2, 
            diffuse=0.8, 
            specular=0.8,
            specular_power=40, 
            smooth_shading=True
        )


pl.background_color = "w"
pl.link_views()
pl.camera_position = "yz"
pos = pl.camera.position
pl.camera.position = (pos[0],pos[1],pos[2]+3)
pl.camera.azimuth = -45
pl.camera.elevation = 10
# create a top down light
light = pv.Light(position=(0, 0, 3), positional=True,
                cone_angle=50, exponent=20, intensity=.2)
pl.add_light(light)
pl.camera.zoom(1.3)
pl.show()



import matplotlib.pyplot as plt

plt.imshow(ect_samples[8].squeeze().detach().cpu().numpy())
plt.savefig("MYECT.png")



Widget(value='<iframe src="http://localhost:52812/index.html?ui=P_0x2115aea8730_6&reconnect=auto" class="pyvis…

In [32]:



# NROWS = 8


# _, s  = model.sample(8,2048)

# new_samples_list = []


# with torch.no_grad():
#     for _ in range(8):
#         s = model.reconstruct(s)
#         new_samples_list.append(s)

# new_samples = torch.cat(new_samples_list)

# pl = pv.Plotter(shape=(NROWS,8), window_size=[1600, NROWS*200],border=False,polygon_smoothing=True)


# for row in range(NROWS):
#     for col in range(8):
#         points = new_samples[row*8 + col].reshape(-1, 3).detach().cpu().numpy()
#         pl.subplot(row, col)
#         actor = pl.add_points(
#             points,
#             style="points",
#             emissive=False,
#             show_scalar_bar=False,
#             render_points_as_spheres=True,
#             scalars=points[:, 2],
#             point_size=2,
#             ambient=0.2, 
#             diffuse=0.8, 
#             specular=0.8,
#             specular_power=40, 
#             smooth_shading=True
#         )


# pl.background_color = "w"
# pl.link_views()
# pl.camera_position = "yz"
# pos = pl.camera.position
# pl.camera.position = (pos[0],pos[1],pos[2]+3)
# pl.camera.azimuth = -45
# pl.camera.elevation = 10
# # create a top down light
# light = pv.Light(position=(0, 0, 3), positional=True,
#                 cone_angle=50, exponent=20, intensity=.2)
# pl.add_light(light)
# pl.camera.zoom(1.3)
# pl.show()
