## Prepare imports

In [2]:
from datasets import get_datasets, synsetid_to_cate
from args import get_parser
from pprint import pprint
# from metrics.evaluation_metrics import EMD_CD
# from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD
# from metrics.evaluation_metrics import compute_all_metrics
from collections import defaultdict
from models.networks import PointFlow
import os
import torch
import numpy as np
import torch.nn as nn
import pyvista as pv

import matplotlib.pyplot as plt

plt.close("all")


################################################################################
# Import model and additional stuff
################################################################################

from omegaconf import OmegaConf
from model_wrapper import TopologicalModelVAE
from load_models import load_encoder, load_vae


################################################################################
################################################################################


def get_test_loader(args):
    """Create a sequential (non-shuffled) DataLoader over the test split.

    Args:
        args: Parsed options object. It is forwarded to ``get_datasets`` and
            this function additionally reads ``batch_size``,
            ``resume_dataset_mean`` and ``resume_dataset_std`` from it.

    Returns:
        A ``torch.utils.data.DataLoader`` wrapping the second dataset
        returned by ``get_datasets(args)``.
    """
    _train_split, test_split = get_datasets(args)

    # Only renormalize when BOTH statistics files were supplied.
    stats_given = (
        args.resume_dataset_mean is not None
        and args.resume_dataset_std is not None
    )
    if stats_given:
        saved_mean = np.load(args.resume_dataset_mean)
        saved_std = np.load(args.resume_dataset_std)
        test_split.renormalize(saved_mean, saved_std)

    # Fixed order, main-process loading, keep the final partial batch.
    loader_options = dict(
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        drop_last=False,
    )
    return torch.utils.data.DataLoader(dataset=test_split, **loader_options)



In [3]:
# Build the project's argument parser and parse the current command line.
# parse_known_args() returns a (namespace, leftover_args) pair; the leftover
# arguments the parser does not recognise are kept in `unknown`.
# NOTE(review): assumes get_parser() returns an argparse.ArgumentParser and
# that parse_known_args is used to tolerate the kernel's own argv - confirm.
p = get_parser()
args, unknown = p.parse_known_args()


In [4]:
# Instead of PointFlow we load our own model wrapper, which
# handles the function signatures of input and output.
# Both config files are selected by the first entry of args.cates, so the
# encoder and the VAE configs always refer to the same category.
ect_config = OmegaConf.load(
    f"./configs/config_encoder_shapenet_{args.cates[0]}.yaml"
)
vae_config = OmegaConf.load(f"./configs/config_vae_shapenet_{args.cates[0]}.yaml")

# Build the two sub-models from their configs and combine them in the wrapper
# that the rest of the notebook samples from.
encoder_model = load_encoder(ect_config)
vae = load_vae(vae_config)
model = TopologicalModelVAE(encoder_model, vae)

# Put the VAE network into evaluation mode. As the last expression of the
# cell, the returned module is also displayed as the cell output.
# NOTE(review): only model.vae.model is switched here; whether the encoder
# also needs eval() cannot be told from this cell - confirm.
model.vae.model.eval()




  | Name                | Type             | Params
---------------------------------------------------------
0 | layer               | EctLayer         | 0     
1 | training_accuracy   | MeanSquaredError | 0     
2 | validation_accuracy | MeanSquaredError | 0     
3 | test_accuracy       | MeanSquaredError | 0     
4 | loss_fn             | MSELoss          | 0     
5 | model               | Sequential       | 39.9 M
---------------------------------------------------------
39.9 M    Trainable params
0         Non-trainable params
39.9 M    Total params
159.433   Total estimated model params size (MB)


C:\Users\ernst\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\lightning\pytorch\utilities\migration\utils.py:56: The loaded checkpoint was produced with Lightning v2.3.3, which is newer than your current Lightning version: v2.2.3


VanillaVAE(
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (1): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (4): Se

In [5]:
# import matplotlib
# matplotlib.use('QtAgg')

# Quick visual check: draw a few ECT images with the VAE's sample() method
# and show the last one.
NUM_ECT_SAMPLES = 10

# Sample on whichever device the VAE weights live on rather than a hard-coded
# "cuda:0", so the cell also runs on CPU or on a different GPU index.
vae_device = next(model.vae.model.parameters()).device

# Inference only: no autograd graph is needed just to plot an image.
with torch.no_grad():
    ect = model.vae.model.sample(NUM_ECT_SAMPLES, vae_device)

plt.imshow(ect[-1].detach().cpu().squeeze())
plt.show()

  plt.show()


In [6]:
print("=====GEN======")
loader = get_test_loader(args)
all_sample = []
all_ref = []

# Evaluation only: with autograd disabled the per-batch outputs collected
# below do not keep a computation graph alive for the whole test set.
with torch.no_grad():
    for data in loader:
        idx_b, te_pc = data["idx"], data["test_points"]
        te_pc = te_pc.cuda() if args.gpu is None else te_pc.cuda(args.gpu)
        # Generate as many clouds, with as many points, as the reference batch.
        B, N = te_pc.size(0), te_pc.size(1)
        _, out_pc = model.sample(B, N)

        # # denormalize
        # m, s = data["mean"].float(), data["std"].float()
        # m = m.cuda() if args.gpu is None else m.cuda(args.gpu)
        # s = s.cuda() if args.gpu is None else s.cuda(args.gpu)
        # out_pc = out_pc * s + m
        # te_pc = te_pc * s + m

        ########################
        ## Insert
        ########################

        # For comparison, each reference point cloud is centered at its mean
        # and divided by its largest point norm, so it has unit radius.
        # The generated clouds (out_pc) are NOT rescaled here; they are stored
        # exactly as returned by model.sample.
        # @ErnstRoell
        te_pc_means = te_pc.mean(dim=-2, keepdim=True)
        te_pc = te_pc - te_pc_means
        te_pc_norms = torch.norm(te_pc, dim=-1, keepdim=True).max(dim=-2, keepdim=True)[0]
        te_pc = te_pc / te_pc_norms

        ########################
        ## End insert
        ########################

        all_sample.append(out_pc)
        all_ref.append(te_pc)

# Stack the per-batch results along the batch dimension.
sample_pcs = torch.cat(all_sample, dim=0)
ref_pcs = torch.cat(all_ref, dim=0)



Total number of data:2832
Min number of points: (train)2048 (test)2048
Total number of data:405
Min number of points: (train)2048 (test)2048


# Reference Point Clouds

In [7]:
# Render the first n_rows * n_cols reference point clouds in a grid of linked views.
n_rows, n_cols = 8, 8
pl = pv.Plotter(shape=(n_rows, n_cols), window_size=[1600, 1600], border=False, polygon_smoothing=True)

for row in range(n_rows):
    for col in range(n_cols):
        # Row-major index so every grid cell shows a different cloud.
        # The earlier `row*col + col` repeated clouds (e.g. cells (0, 2) and
        # (1, 1) both mapped to index 2).
        cloud_idx = row * n_cols + col
        points = ref_pcs[cloud_idx].reshape(-1, 3).detach().cpu().numpy()
        pl.subplot(row, col)
        actor = pl.add_points(
            points,
            style="points",
            emissive=False,
            show_scalar_bar=False,
            render_points_as_spheres=True,
            scalars=points[:, 2],  # colour by the third coordinate
            point_size=5,
            ambient=0.2,
            diffuse=0.8,
            specular=0.8,
            specular_power=40,
            smooth_shading=True,
        )


pl.background_color = "w"
# Share one camera across all subplots so the views move together.
pl.link_views()
pl.camera_position = "yz"
pos = pl.camera.position
pl.camera.position = (pos[0], pos[1], pos[2] + 3)
pl.camera.azimuth = -45
pl.camera.elevation = 10
# create a top down light
light = pv.Light(position=(0, 0, 3), positional=True,
                cone_angle=50, exponent=20, intensity=.2)
pl.add_light(light)
pl.camera.zoom(1.3)
pl.show()

Widget(value='<iframe src="http://localhost:64626/index.html?ui=P_0x24ee204be20_0&reconnect=auto" class="pyvis…

# Generated Point Clouds

In [8]:
# Render the first n_rows * n_cols sampled point clouds in a grid of linked views.
n_rows, n_cols = 8, 8
pl = pv.Plotter(shape=(n_rows, n_cols), window_size=[1600, 1600], border=False, polygon_smoothing=True)

for row in range(n_rows):
    for col in range(n_cols):
        # Row-major index so every grid cell shows a different cloud.
        # The earlier `row*col + col` repeated clouds (e.g. cells (0, 2) and
        # (1, 1) both mapped to index 2).
        cloud_idx = row * n_cols + col
        points = sample_pcs[cloud_idx].reshape(-1, 3).detach().cpu().numpy()
        pl.subplot(row, col)
        actor = pl.add_points(
            points,
            style="points",
            emissive=False,
            show_scalar_bar=False,
            render_points_as_spheres=True,
            scalars=points[:, 2],  # colour by the third coordinate
            point_size=5,
            ambient=0.2,
            diffuse=0.8,
            specular=0.8,
            specular_power=40,
            smooth_shading=True,
        )


pl.background_color = "w"
# Share one camera across all subplots so the views move together.
pl.link_views()
pl.camera_position = "yz"
pos = pl.camera.position
pl.camera.position = (pos[0], pos[1], pos[2] + 3)
pl.camera.azimuth = -45
pl.camera.elevation = 10
# create a top down light
light = pv.Light(position=(0, 0, 3), positional=True,
                cone_angle=50, exponent=20, intensity=.2)
pl.add_light(light)
pl.camera.zoom(1.3)
pl.show()

Widget(value='<iframe src="http://localhost:64626/index.html?ui=P_0x24f89616b90_1&reconnect=auto" class="pyvis…

In [9]:
# Norm of every point (taken over the last, coordinate, dimension) for the
# reference and the sampled clouds.
ref_norm = ref_pcs.norm(dim=-1)
samples_norm = sample_pcs.norm(dim=-1)

# Shapes first, then the largest point norm of each cloud, reference before
# samples in both cases.
print(ref_norm.shape, samples_norm.shape, sep="\n")
print(ref_norm.max(dim=-1).values, samples_norm.max(dim=-1).values, sep="\n")

torch.Size([405, 2048])
torch.Size([405, 2048])
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 

In [18]:
import torch 
import matplotlib.pyplot as plt 
from normalization import normalize

# Run the project's normalize() on random clouds and look at one of them.
a = normalize(torch.rand(10,2048,3))


# A single cloud needs a single view. The earlier version built an 8x8 grid
# and called pl.subplot(row, col) with `row`/`col` left over from the loops of
# previous cells, which fails on a fresh kernel and otherwise draws into the
# last cell of an empty grid.
pl = pv.Plotter(shape=(1, 1), window_size=[1600, 1600], border=False, polygon_smoothing=True)

points = a[1].reshape(-1, 3).detach().cpu().numpy()
pl.subplot(0, 0)
actor = pl.add_points(
    points,
    style="points",
    emissive=False,
    show_scalar_bar=False,
    render_points_as_spheres=True,
    scalars=points[:, 2],  # colour by the third coordinate
    point_size=5,
    ambient=0.2,
    diffuse=0.8,
    specular=0.8,
    specular_power=40,
    smooth_shading=True,
)


pl.background_color = "w"
pl.link_views()
pl.camera_position = "yz"
pos = pl.camera.position
pl.camera.position = (pos[0], pos[1], pos[2] + 3)
pl.camera.azimuth = -45
pl.camera.elevation = 10
# create a top down light
light = pv.Light(position=(0, 0, 3), positional=True,
                cone_angle=50, exponent=20, intensity=.2)
pl.add_light(light)
pl.camera.zoom(1.3)
pl.show()

Widget(value='<iframe src="http://localhost:64626/index.html?ui=P_0x24f9ee21300_8&reconnect=auto" class="pyvis…