In [8]:
# Automatically reload edited project modules (e.g. the partglot package) before running code,
# so changes to .py files are picked up without restarting the kernel.
# Re-executing this cell in the same kernel only prints an "already loaded" notice.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
import hydra
import torch
import os
import os.path as osp
import yaml
import matplotlib.pyplot as plt
import k3d
import numpy as np
import torch
from torch.utils.data.dataloader import DataLoader
from partglot.datamodules.partglot_datamodule import PartglotDataModule
from partglot.datamodules.datasets.partglot_dataset import PartglotTestDataset
from partglot.models.pn_agnostic import PNAgnostic
from pytorch_lightning import Trainer

In [10]:
def visualize_pointcloud(point_cloud, point_size):
    """Render one point cloud as a single uniformly grey k3d point set.

    Args:
        point_cloud: point positions, passed unchanged to ``k3d.points``.
        point_size: marker size used for every point.

    The scene's grid spans [-0.55, 0.55] on each axis and is hidden; the
    points use the '3d' shader. The plot is displayed and nothing is returned.
    """
    bounds = (-0.55, -0.55, -0.55, 0.55, 0.55, 0.55)
    scene = k3d.plot(grid_visible=False, grid=bounds)

    cloud = k3d.points(positions=point_cloud, point_size=point_size, color=0xd0d0d0)
    scene += cloud
    cloud.shader = '3d'

    scene.display()

In [121]:
def _gt_label_per_super_seg(gt, sd):
    """Return, for every super segment, the most frequent GT label among its points.

    Each point is assigned to the super segment with the largest entry in its row
    of ``sd`` (``np.argmax(sd, 1)``); a super segment's label is the mode of ``gt``
    over the points assigned to it. A super segment that receives no point keeps
    the initial label 0 instead of crashing on an empty ``np.unique`` result.
    """
    n_super_segs = sd.shape[-1]
    pc2sup_segs = np.argmax(sd, 1)
    labels = np.zeros(n_super_segs)
    for i in range(n_super_segs):
        vals, counts = np.unique(gt[pc2sup_segs == i], return_counts=True)
        if counts.size == 0:
            # NOTE(review): no point maps to this super segment; 0 is only a fallback.
            continue
        labels[i] = vals[np.argmax(counts)]
    return labels


def visualize_sample(idx, dataset, prediction_path="logs/pre_trained/pn_agnostic/11-02_17-07-50/pred_label/final/"):
    """Plot the super segments of one sample, coloured by ground-truth (and predicted) label.

    Ground-truth labels are drawn with point size 0.025. If ``prediction_path``
    is not None, the predicted per-super-segment labels are loaded from
    ``<prediction_path>/<idx>_mesh_label.npy`` and drawn as a second copy
    (point size 0.015) shifted by +1 along the third coordinate.

    Args:
        idx: sample index into ``dataset``.
        dataset: object with ``dataset[idx] -> (geos, geos_mask)`` and
            ``get_groundtruth_and_signed_distance(idx) -> (gt, sd)``.
        prediction_path: directory with the saved predictions, or None to plot
            ground truth only. A trailing path separator is optional.

    Only labels 0-3 are drawn (one colour each); other labels are not shown.
    """
    geos, geos_mask = dataset[idx]
    gt, sd = dataset.get_groundtruth_and_signed_distance(idx)

    gt_label_per_super_seg = _gt_label_per_super_seg(gt, sd)

    predicted_label_per_super_seg = None
    if prediction_path is not None:
        predicted_label_per_super_seg = np.load(
            os.path.join(prediction_path, f"{idx}_mesh_label.npy"))

    # Label arrays index the masked super segments; presumably the mask drops
    # padding entries (geos is padded to a fixed count) — TODO confirm.
    valid_geos = geos[geos_mask.bool()]
    pred_offset = torch.tensor([0, 0, 1])[None, None, :]

    plot = k3d.plot(grid_visible=False, grid=(-0.55, -0.55, -0.55, 0.55, 0.55, 0.55))

    colors = [0xd0d0d0, 0xFF0000, 0x008000, 0x0000FF]
    for label, color in enumerate(colors):
        # shader is set per point set; assigning it after the loop only affected the last one
        plot += k3d.points(positions=valid_geos[gt_label_per_super_seg == label],
                           point_size=0.025, color=color, shader='3d')
        if predicted_label_per_super_seg is not None:
            plot += k3d.points(positions=pred_offset + valid_geos[predicted_label_per_super_seg == label],
                               point_size=0.015, color=color, shader='3d')

    plot.display()

In [88]:
# Test split of the PartGlot data, read from the local data/ directory.
test_set = PartglotTestDataset(dict(data_dir="data/"))

In [89]:
# Sanity-check that the per-sample containers of the test set line up: one mask row
# per padded segment set, and one GT / signed-distance entry per sample.
assert tuple(test_set.segs_mask.shape) == tuple(test_set.segs_data.shape[:2]), "segs_mask does not match segs_data"
assert len(test_set.groundtruths) == test_set.segs_data.shape[0], "groundtruths length mismatch"
assert len(test_set.signed_distances) == test_set.segs_data.shape[0], "signed_distances length mismatch"
test_set.segs_data.shape, test_set.segs_mask.shape, len(test_set.groundtruths), len(test_set.signed_distances)

((3746, 50, 512, 3), (3746, 50), 3746, 3746)

In [122]:
# Show GT vs. prediction for test sample #5 (predictions from the default prediction_path).
visualize_sample(idx=5, dataset=test_set)

Output()

# Re-run the `test.py` evaluation

In [7]:
# Rebuild the datamodule and model, load the pretrained weights and run the
# Lightning test loop (a notebook re-run of test.py).
DATA_DIR = "data/"
CKPT_PATH = "checkpoints/pn_agnostic.ckpt"

datamodule = PartglotDataModule(batch_size=64,
        only_correct=True,
        only_easy_context=False,
        max_seq_len=33,
        only_one_part_name=True,
        seed=12345678,
        split_sizes=[0.8, 0.1, 0.1],
        balance=True,
        data_dir=DATA_DIR)

model = PNAgnostic(text_dim=64,
        embedding_dim=100,
        sup_segs_dim=64,
        lr=1e-3,
        data_dir=DATA_DIR,
        word2int=datamodule.word2int,
        total_steps=1,
        measure_iou_every_epoch=True,
        save_pred_label_every_epoch=False)

# Load onto CPU so this also works without CUDA; load_state_dict copies the
# tensors into the model and the Trainer moves the model to the device itself.
# NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
ckpt = torch.load(CKPT_PATH, map_location="cpu")
if "state_dict" in ckpt:
    # Lightning checkpoints nest the model weights under "state_dict".
    print("extracting 'state_dict' from Lightning checkpoint")
    ckpt = ckpt["state_dict"]

model.load_state_dict(ckpt)

# Use GPU 0 when available, otherwise fall back to CPU (gpus=None).
trainer = Trainer(logger=False, gpus=[0] if torch.cuda.is_available() else None)

trainer.test(model=model, datamodule=datamodule)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


write state dict
