In [1]:
# imports
from dataset import SHREC2022Primitives
from networks import MinkowskiFCNN
import transforms as t
from losses import CylinderLoss
import MinkowskiEngine as ME

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
from utils import minkowski_collate, create_input_batch

import geomfitty
from geomfitty import geom3d
from geomfitty import plot
from geomfitty import fit3d
# configuration & hyperparameters
# The path defaults below are absolute paths on the author's machine. Set the
# SHREC2022_DATA_PATH / SHREC2022_CHECKPOINT_PATH environment variables to run
# the notebook elsewhere without editing this cell.
path = os.environ.get(
    "SHREC2022_DATA_PATH",
    "/home/ioannis/Desktop/programming/data/SHREC/SHREC2022/dataset",
)
# The plotting cell indexes sample[...][0] and squeezes the network outputs,
# so it relies on one sample per batch.
batch_size = 1
# Training hyperparameters; not referenced by the inference cells shown here.
lr = 1e-3
num_epochs = 100
eval_step = 1000000000
# Number of network outputs (used as `out_channel`); presumably radius (1) +
# axis (3) + vertex (3) as unpacked by CylinderLoss.transform_cylinder_outputs
# — confirm against CylinderLoss.
regression_targets = 7

checkpoint_path = os.environ.get(
    "SHREC2022_CHECKPOINT_PATH",
    "/home/ioannis/Desktop/programming/phd/SHREC/SHREC2022/checkpoints",
)
checkpoint_file = os.path.join(checkpoint_path, "cylinder_100.pth")

In [2]:
# Dataset and dataloader
def _base_transforms():
    """Return fresh instances of the preprocessing steps shared by both splits."""
    return [t.KeepInitialPoints(), t.Translate(), t.SphereNormalization()]


# Training split: shared preprocessing, then initialization, a random rotation
# about each of the three axes, Gaussian noise, and mean extraction.
train_transforms = _base_transforms() + [
    t.Initialization(),
    *(t.RandomRotate(180, rot_axis) for rot_axis in range(3)),
    t.GaussianNoise(),
    t.GetMean(),
]

# Validation split: shared preprocessing and mean extraction only (no augmentation).
valid_transforms = _base_transforms() + [t.GetMean()]

# Both datasets are carved out of the training data (train=True); the `valid`
# flag selects which side of the 0.2 split each one gets.
_split_kwargs = dict(train=True, valid_split=0.2, category="cylinder")
_loader_kwargs = dict(batch_size=batch_size, collate_fn=minkowski_collate, num_workers=8)

t_dataset = SHREC2022Primitives(path, valid=False, transform=train_transforms, **_split_kwargs)
train_loader = DataLoader(t_dataset, shuffle=True, **_loader_kwargs)

v_dataset = SHREC2022Primitives(path, valid=True, transform=valid_transforms, **_split_kwargs)
valid_loader = DataLoader(v_dataset, shuffle=False, **_loader_kwargs)

Creating a new train-validation split.
Specified split already exists. Using the existing one.


In [3]:
# network and device
# Use the first GPU when available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

net = MinkowskiFCNN(in_channel=3, out_channel=regression_targets).to(device)
# map_location remaps tensors saved during a GPU run onto `device`; without it,
# loading a CUDA-saved checkpoint on a CPU-only machine raises a RuntimeError,
# defeating the CPU fallback above.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
state_dict = torch.load(checkpoint_file, map_location=device)
net.load_state_dict(state_dict)
net.eval()  # inference mode (switches dropout / batch-norm layers, if any)
# Instantiated only to reuse its "transform_cylinder_outputs" method in the
# plotting cell; no loss is computed in the cells shown here.
l = CylinderLoss()

Using device: cuda:0


In [4]:
# Plot each predicted cylinder against the sample's initial (untransformed) points.
# Set max_samples to an int to stop after that many figures; None iterates the
# whole validation split (one figure per sample).
max_samples = None

# Pure inference: no_grad avoids building an autograd graph for every forward pass.
with torch.no_grad():
    for sample_idx, sample in enumerate(valid_loader):
        if max_samples is not None and sample_idx >= max_samples:
            break

        # NOTE(review): quantization_size presumably must match the value used
        # when the checkpoint was trained — confirm.
        minknet_input = create_input_batch(
            sample,
            device=device,
            quantization_size=0.05,
        )
        pred = net(minknet_input)

        # Convert raw outputs to cylinder parameters; sample['trans'] is passed
        # through, presumably to undo the dataset normalization — confirm in CylinderLoss.
        r, axis, vertex = l.transform_cylinder_outputs(pred, sample['trans'])
        # squeeze() here and the [0] index below assume batch_size == 1.
        r = r.detach().cpu().squeeze().numpy()
        axis = axis.detach().cpu().squeeze().numpy()
        vertex = vertex.detach().cpu().squeeze().numpy()
        points = sample['initial_points'][0].numpy()

        cylinder = geom3d.Cylinder(vertex, axis, r)
        plot.plot([points, cylinder])

KeyboardInterrupt: 