In [1]:
#imports
import torch
from torch.utils.data import DataLoader
import MinkowskiEngine as ME
from tqdm import tqdm
import os
import time
import numpy as np

import transforms as t
from dataset import SHREC2022Primitives
from networks import MinkowskiFCNN
from utils import minkowski_collate, create_input_batch 
import losses

import geomfitty
import cone_fit


# Configuring paths for data and checkpoints.
# The defaults are the locations used on the original development machine.
# Set the SHREC_DATA_PATH / SHREC_CHECKPOINT_PATH environment variables to
# run the notebook on another machine without editing this cell.
data_path = os.environ.get(
    "SHREC_DATA_PATH",
    "/home/ioannis/Desktop/programming/data/SHREC/SHREC2022/dataset",
)
checkpoint_path = os.environ.get(
    "SHREC_CHECKPOINT_PATH",
    "/home/ioannis/Desktop/programming/phd/SHREC/SHREC2022/checkpoints",
)

# One checkpoint file for the classifier and one per primitive regressor.
cls_checkpoint = os.path.join(checkpoint_path, "classification.pth")
plane_checkpoint = os.path.join(checkpoint_path, "plane.pth")
sphere_checkpoint = os.path.join(checkpoint_path, "sphere.pth")
cylinder_checkpoint = os.path.join(checkpoint_path, "cylinder.pth")
cone_checkpoint = os.path.join(checkpoint_path, "cone.pth")
torus_checkpoint = os.path.join(checkpoint_path, "torus.pth")

# Transform pipelines
# (a single augmentation recipe is shared by every network and task)

# Training pipeline: the validation steps plus Initialization, three
# RandomRotate steps and GaussianNoise.
train_transforms = [
    t.KeepInitialPoints(),
    t.Translate(),
    t.SphereNormalization(),
    t.Initialization(),
    # RandomRotate(180, k) for k = 0, 1, 2
    # (presumably max angle in degrees and axis index -- TODO confirm in transforms.py)
    *(t.RandomRotate(180, axis) for axis in range(3)),
    t.GaussianNoise(),
    t.GetMean(),
]

# Validation / test pipeline: no random augmentation steps.
valid_transforms = [
    make_transform()
    for make_transform in (
        t.KeepInitialPoints,
        t.Translate,
        t.SphereNormalization,
        t.GetMean,
    )
]


# Device selection: use CUDA when PyTorch can see a GPU, otherwise the CPU.
cpu_device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using Device: {device}")

# Turn autograd off globally: gradients are not needed for what follows,
# and skipping the tracking improves performance.
torch.set_grad_enabled(False)

Using Device: cuda


<torch.autograd.grad_mode.set_grad_enabled at 0x7f6e3960d6a0>

In [2]:
# Loading networks
def _load_network(out_channel, checkpoint):
    """Build a MinkowskiFCNN, restore its weights and switch it to eval mode.

    Args:
        out_channel: number of network outputs (classes or regressed values).
        checkpoint: path of the state-dict file to restore.

    Returns:
        The network, placed on the global ``device`` and set to ``eval()``.
    """
    network = MinkowskiFCNN(in_channel=3,   # point features
                            out_channel=out_channel,
                            ).to(device)
    # map_location lets checkpoints saved from GPU tensors load on a CPU-only
    # machine, which the "cuda"/"cpu" fallback used for `device` allows.
    network.load_state_dict(torch.load(checkpoint, map_location=device))
    network.eval()
    return network


# classifier: 5 outputs, one per primitive class
cls_network = _load_network(5, cls_checkpoint)

# IMPORTANT: regression only for the plane normal (3 outputs),
# as the point is calculated as the average of the point cloud
plane_net = _load_network(3, plane_checkpoint)

# per-primitive parameter regressors
cylinder_net = _load_network(7, cylinder_checkpoint)
sphere_net = _load_network(4, sphere_checkpoint)
cone_net = _load_network(7, cone_checkpoint)
torus_net = _load_network(8, torus_checkpoint)
print("Networks Initialized")

Networks Initialized


# Dataset

In [3]:
batch_size = 1

# DataLoader settings common to both splits: fixed sample order, the
# Minkowski collate function and 8 worker processes.
_loader_kwargs = dict(
    batch_size=batch_size,
    shuffle=False,
    collate_fn=minkowski_collate,
    num_workers=8,
)

# Training split (all categories, no validation hold-out).
# Note that it is read through `train_transforms`.
train_dataset = SHREC2022Primitives(
    data_path,
    train=True,
    valid=False,
    valid_split=0.0,
    transform=train_transforms,
    category="all",
)
train_loader = DataLoader(train_dataset, **_loader_kwargs)

# Test split (all categories), read through `valid_transforms`.
test_dataset = SHREC2022Primitives(
    data_path,
    train=False,
    valid_split=0.0,
    transform=valid_transforms,
    category="all",
)
test_loader = DataLoader(test_dataset, **_loader_kwargs)


Dataset path:  /home/ioannis/Desktop/programming/data/SHREC/SHREC2022/dataset/training
Specified split already exists. Using the existing one.
Dataset path:  /home/ioannis/Desktop/programming/data/SHREC/SHREC2022/dataset/test


In [4]:
def ls_fit_shape(point_cloud, cls, params):
    """Refine shape parameters with a least-squares fit seeded by ``params``.

    Parameter layout per class id, as sliced below:
        1 (cylinder): [radius, axis(3), point(3)]
        2 (sphere):   [radius, center(3)]
        3 (cone):     [theta, axis(3), vertex(3)]
        4 (torus):    [R, r, axis(3), center(3)]
    Class 0 (plane) and any other id get no fit; ``params`` passes through.

    Args:
        point_cloud: points to fit against (torch tensor or array-like).
        cls: shape class id.
        params: initial parameters for the fit (torch tensor or array-like).

    Returns:
        ``torch.tensor`` of the refined parameters. If the fitting routine
        does not return the expected shape object, a message is printed and
        the input parameters are returned instead.
    """
    # The fitting code receives NumPy data, so move tensors to the CPU and convert.
    if isinstance(point_cloud, torch.Tensor):
        point_cloud = point_cloud.cpu().numpy()
    if isinstance(params, torch.Tensor):
        # print(params)
        params = params.cpu().numpy()

    # cls == 0 (plane): nothing to refine, so it falls through every branch.
    if cls == 1:  # cylinder
        seed = geomfitty.geom3d.Cylinder(params[4:7], params[1:4], params[0])
        fitted = geomfitty.fit3d.cylinder_fit(point_cloud, weights=None, initial_guess=seed)
        if isinstance(fitted, geomfitty.geom3d.Cylinder):
            params = np.concatenate(
                [np.array([fitted.radius]), fitted.direction, fitted.anchor_point]
            )
        else:
            print("Cylinder: Least squares method failed")

    elif cls == 2:  # sphere
        seed = geomfitty.geom3d.Sphere(params[1:4], params[0])
        fitted = geomfitty.fit3d.sphere_fit(point_cloud, weights=None, initial_guess=seed)
        if isinstance(fitted, geomfitty.geom3d.Sphere):
            params = np.concatenate([np.array([fitted.radius]), fitted.center])
        else:
            print("Sphere: Least squares method failed")

    elif cls == 3:  # cone
        seed = cone_fit.Cone(params[0], params[1:4], params[4:7])
        fitted = cone_fit.cone_fit(point_cloud, weights=None, initial_guess=seed)
        if isinstance(fitted, cone_fit.Cone):
            params = np.concatenate(
                [np.array([fitted.theta]), fitted.axis, fitted.vertex]
            )
        else:
            print("Cone: Least squares method failed")

    elif cls == 4:  # torus
        seed = geomfitty.geom3d.Torus(params[5:8], params[2:5], params[0], params[1])
        fitted = geomfitty.fit3d.torus_fit(point_cloud, weights=None, initial_guess=seed)
        if isinstance(fitted, geomfitty.geom3d.Torus):
            params = np.concatenate(
                [
                    np.array([fitted.major_radius]),
                    np.array([fitted.minor_radius]),
                    fitted.direction,
                    fitted.center,
                ]
            )
        else:
            print("Torus: Least squares method failed")

    return torch.tensor(params)

In [5]:
def regress_params(shape, net_in, batch, trans):
    """Predict the parameters of one primitive with its dedicated regressor.

    The raw network output is passed, together with ``trans``, through the
    matching ``transform_*_outputs`` method of the loss class. The returned
    pieces are then concatenated along the last dimension.

    Args:
        shape: shape class id (0 plane, 1 cylinder, 2 sphere, 3 cone, 4 torus).
        net_in: input fed to the regression network.
        batch: data batch; only ``batch['means']`` is read, for the plane.
        trans: forwarded unchanged to the ``transform_*_outputs`` call.

    Returns:
        Tensor of concatenated shape parameters:
        plane    -> the transformed plane outputs, concatenated
        cylinder -> [r, axis, vertex]
        sphere   -> [r, center]
        cone     -> [theta, axis, vertex]
        torus    -> [R, r, axis, center]

    Raises:
        ValueError: if ``shape`` is not one of the five supported ids.
    """
    if shape == 0:  # plane
        # The network regresses only the normal; the plane point is taken
        # from the batch means.
        normal = plane_net(net_in)
        point = batch['means'].to(device)  # device is global
        params = torch.cat([normal, point], dim=-1)
        params = torch.cat(losses.PlaneLoss().transform_plane_outputs(params, trans), dim=-1)

    elif shape == 1:  # cylinder
        params = cylinder_net(net_in)
        r, axis, vertex = losses.CylinderLoss().transform_cylinder_outputs(params, trans)
        r = r.unsqueeze(-1)  # give the scalar a trailing dim so it can be concatenated
        params = torch.cat([r, axis, vertex], dim=-1)

    elif shape == 2:  # sphere
        params = sphere_net(net_in)
        r, center = losses.SphereLoss().transform_sphere_outputs(params, trans)
        r = r.unsqueeze(-1)
        params = torch.cat([r, center], dim=-1)

    elif shape == 3:  # cone
        params = cone_net(net_in)
        theta, axis, vertex = losses.ConeLoss().transform_cone_outputs(params, trans)
        theta = theta.unsqueeze(-1)
        params = torch.cat([theta, axis, vertex], dim=-1)

    elif shape == 4:  # torus
        params = torus_net(net_in)
        R, r, axis, center = losses.TorusLoss().transform_torus_outputs(params, trans)
        R = R.unsqueeze(-1)
        r = r.unsqueeze(-1)
        params = torch.cat([R, r, axis, center], dim=-1)

    else:
        # Previously this fell through to `return params` and failed with an
        # UnboundLocalError that hid the real cause.
        raise ValueError(f"Unsupported shape type: {shape!r}")

    return params

In [6]:
def distance_points_shape(shape_type, shape_params, initial_points):
    """Mean distance between a parametrised primitive and a set of points.

    Args:
        shape_type: shape class id (0 plane, 1 cylinder, 2 sphere, 3 cone,
            4 torus).
        shape_params: flat parameter vector for that shape (tensor or
            array-like), laid out as
            plane    [normal(3), point(3)]
            cylinder [radius, axis(3), point(3)]
            sphere   [radius, center(3)]
            cone     [theta, axis(3), vertex(3)]
            torus    [R, r, axis(3), center(3)]
        initial_points: points to measure against (tensor or array-like).

    Returns:
        ``shape.distance_to_point(initial_points).mean(0)``.

    Raises:
        ValueError: if ``shape_type`` is not one of the five supported ids.
    """
    # The geometry classes receive NumPy data, so move tensors to the CPU and convert.
    if isinstance(shape_params, torch.Tensor):
        shape_params = shape_params.cpu().numpy()

    if isinstance(initial_points, torch.Tensor):
        initial_points = initial_points.cpu().numpy()

    if shape_type == 0:  # plane
        normal = shape_params[:3]
        vertex = shape_params[3:]
        # NOTE(review): every other geomfitty shape here is built with its
        # position first; confirm that Plane really expects (normal, point)
        # in this order.
        shape = geomfitty.geom3d.Plane(normal, vertex)

    elif shape_type == 1:  # cylinder
        radius = shape_params[0]
        axis = shape_params[1:4]
        vertex = shape_params[4:7]
        shape = geomfitty.geom3d.Cylinder(vertex, axis, radius)

    elif shape_type == 2:  # sphere
        radius = shape_params[0]
        center = shape_params[1:4]
        shape = geomfitty.geom3d.Sphere(center, radius)

    elif shape_type == 3:  # cone
        theta = shape_params[0]
        axis = shape_params[1:4]
        vertex = shape_params[4:7]
        shape = cone_fit.Cone(theta, axis, vertex)

    elif shape_type == 4:  # torus
        major_radius = shape_params[0]
        minor_radius = shape_params[1]
        axis = shape_params[2:5]
        center = shape_params[5:8]
        shape = geomfitty.geom3d.Torus(center, axis, major_radius, minor_radius)

    else:
        # Previously this only printed "NOT ACCESSED" and then crashed on
        # the unbound `shape` below.
        raise ValueError(f"Unsupported shape type: {shape_type!r}")

    distance = shape.distance_to_point(initial_points).mean(0)

    return distance
    

In [7]:
# Loss objects are created once and reused by calc_loss on every call.
cylinder_loss = losses.CylinderLoss()
sphere_loss = losses.SphereLoss()
cone_loss = losses.ConeLoss()
torus_loss = losses.TorusLoss()

def calc_loss(shape_type, ls_shape_params, shape_params):
    """Loss between the least-squares parameters and the network's.

    Args:
        shape_type: shape class id (0 plane, 1 cylinder, 2 sphere, 3 cone,
            4 torus).
        ls_shape_params: 1-D tensor of least-squares parameters; a leading
            dimension is added and it is moved to ``device`` as float.
        shape_params: the network's parameters, used as the second loss
            argument.

    Returns:
        ``0`` for planes; otherwise the sum of the terms returned by the
        loss object for that shape.

    Raises:
        ValueError: if ``shape_type`` is not one of the five supported ids.
    """
    ls_shape_params = ls_shape_params.unsqueeze(0).to(device).float()

    if shape_type == 0:  # plane: the comparison is defined as zero here
        loss = 0

    elif shape_type == 1:  # cylinder
        # The third loss argument is passed as None
        # (presumably the transform -- TODO confirm in losses.py).
        loss = sum(cylinder_loss(ls_shape_params, shape_params, None))

    elif shape_type == 2:  # sphere
        loss = sum(sphere_loss(ls_shape_params, shape_params, None))

    elif shape_type == 3:  # cone
        loss = sum(cone_loss(ls_shape_params, shape_params, None))

    elif shape_type == 4:  # torus
        loss = sum(torus_loss(ls_shape_params, shape_params, None))

    else:
        # Previously `loss` stayed unbound and `return loss` raised an
        # UnboundLocalError.
        raise ValueError(f"Unsupported shape type: {shape_type!r}")

    return loss

In [8]:
def save_shape_prediction(path, shape_type, shape_params):
    """Write one shape prediction to a text file, one value per line.

    The first line is ``shape_type``; each following line is one parameter.

    Args:
        path: destination file; an empty string disables saving.
        shape_type: 0-based shape class id (int, or a single-element
            tensor / NumPy integer).
        shape_params: parameters of the shape. A tensor is flattened and
            converted to Python floats before writing.

    Raises:
        AssertionError: if the number of parameters does not match the
            expected count for the shape.
    """
    # The guard must look at the `path` argument. It used to test the global
    # `out_path`, a lambda, which is never equal to "".
    if path == "":
        return

    # Expected number of parameters, keyed by the 1-based shape id.
    sizes = {
        "1": 6,  # plane
        "2": 7,  # cylinder
        "3": 4,  # sphere
        "4": 7,  # cone
        "5": 8   # torus
    }

    # A tensor id would otherwise produce a key like "tensor(2)" and raise
    # KeyError.
    shape_type = int(shape_type)

    # A (1, k) network output would otherwise have length 1, and its elements
    # would be written as "tensor(...)" text.
    if torch.is_tensor(shape_params):
        shape_params = shape_params.detach().cpu().flatten().tolist()

    assert sizes[str(shape_type + 1)] == len(shape_params), (
        f"shape type {shape_type} expects {sizes[str(shape_type + 1)]} "
        f"parameters, got {len(shape_params)}"
    )

    # NOTE(review): the id written is the 0-based shape_type, while the size
    # table is keyed by shape_type + 1 -- confirm which convention the
    # consumer of this file expects.
    with open(path, "w") as F:

        F.write(str(shape_type) + "\n")

        for param in shape_params:

            F.write(str(param) + "\n")

# Train Pipeline

In [None]:
out_path = lambda i : f"/home/ioannis/Desktop/programming/data/SHREC/SHREC2022/dataset/results/pointCloud{i}_prediction.txt"


# Track the acc on the trainset
acc = 0

# Largest accepted mean point-to-shape distance (after dividing by the
# normalisation factor); a larger value rejects the current class guess.
thresh = 10


for i, batch in enumerate(train_loader):
    
    #loading labels
    labels = batch["labels"][:, 0].long().to(device)

    #voxelizing the input
    minknet_input = create_input_batch(
        batch, 
        device=device,
        quantization_size = 0.05
    )
    
    #keeping the transform to project the points back to their original positions
    trans = batch['trans']
    
    #predicting probabilities for each shape_class
    cls_pred = torch.nn.functional.softmax(cls_network(minknet_input), dim=-1).squeeze()
    
    
    #keeping the initial points for distance measuring purposes
    initial_points = batch['initial_points'][0]

    
    total_distance = thresh + 1 
    flag = True
    shape_type = 0
    initial_cls_guess = torch.max(cls_pred, dim=-1).indices
    regression_count = 0
    #while the distance between shape and points is large, move on to the next predicted shape
    while total_distance > thresh and flag:
        
        if regression_count > 0:
            print(f"Had to regress! {regression_count} regressions")
            
        regression_count+=1
        
        #taking the i-th network's prediction (in descending order of probability)
        best_guess = torch.max(cls_pred, dim=-1)
        shape_type = best_guess.indices
        
        # Rejected classes are overwritten with -1 below, and softmax outputs
        # are never negative, so a negative maximum means every class has
        # been tried. The old check compared the *index* with -1, which can
        # never be true and left this loop spinning forever.
        if best_guess.values.item() < 0: 
            flag = False # failed to regress a shape
            print("Failed regression!")
            shape_type = initial_cls_guess # using the first guess
        
        #using the appropriate shape regressor to predict the shape parameters
        shape_params = regress_params(shape_type, minknet_input, batch, trans)

        #calculating the distance between the predicted shape and the input point cloud
        total_distance = distance_points_shape(shape_type, shape_params.squeeze(0), initial_points)
        total_distance /= trans['norm_factors'][0]
        
        #crossing out the guess
        cls_pred[shape_type.item()] = -1

    #using the predicted shape parameters as initial guess for least squares fitting
    ls_shape_params = ls_fit_shape(initial_points, shape_type, shape_params.squeeze(0))
    #If least squares fitting has exploded then revert back to the network's prediction
    # (a loss of 50 or more between the two parameter sets counts as exploded)
    ls_fit_loss = calc_loss(shape_type, ls_shape_params, shape_params)
    if ls_fit_loss < 50:
        shape_params = ls_shape_params
    else:
        print("Least squares method failed, keeping network prediction")
    
    #saving the output
    #save_shape_prediction("", shape_type, shape_params)
    
    acc += (shape_type == labels).item()
    
    #reporting the running accuracy every 100 samples
    if (i+1)%100 == 0:
        print(f"acc: {acc/(i+1)}")

    

acc: 1.0
acc: 1.0
acc: 1.0
acc: 1.0
acc: 1.0
acc: 1.0
acc: 0.9985714285714286
acc: 0.99875
Least squares method failed, keeping network prediction
acc: 0.9977777777777778
acc: 0.998
acc: 0.9981818181818182
acc: 0.9983333333333333
acc: 0.9984615384615385
acc: 0.9985714285714286
acc: 0.9986666666666667
acc: 0.99875
Least squares method failed, keeping network prediction
acc: 0.9982352941176471
acc: 0.9983333333333333
acc: 0.998421052631579
acc: 0.9985
acc: 0.9985714285714286
acc: 0.9986363636363637
acc: 0.9986956521739131
acc: 0.99875
acc: 0.9988
acc: 0.9988461538461538
acc: 0.9988888888888889
Least squares method failed, keeping network prediction
Least squares method failed, keeping network prediction
acc: 0.9982142857142857
acc: 0.9982758620689656
acc: 0.9983333333333333
acc: 0.9983870967741936
acc: 0.9984375
acc: 0.9984848484848485
acc: 0.9985294117647059
acc: 0.9985714285714286
acc: 0.9986111111111111
acc: 0.9986486486486487
acc: 0.9986842105263158
acc: 0.9987179487179487
acc: 0.998

Least squares method failed, keeping network prediction
acc: 0.9842335766423358
acc: 0.9842753623188406
Cone: Least squares method failed
acc: 0.9841726618705036
Least squares method failed, keeping network prediction
Cone: Least squares method failed
acc: 0.9839285714285714
acc: 0.9838297872340426
Least squares method failed, keeping network prediction
Least squares method failed, keeping network prediction
acc: 0.9836619718309859
Cone: Least squares method failed
Had to regress! 1 regressions
Least squares method failed, keeping network prediction
acc: 0.9834265734265735
Cone: Least squares method failed
Least squares method failed, keeping network prediction
acc: 0.9831944444444445
Cone: Least squares method failed
Cone: Least squares method failed
Least squares method failed, keeping network prediction
acc: 0.9827586206896551
Had to regress! 1 regressions
Least squares method failed, keeping network prediction
Least squares method failed, keeping network prediction
Had to regress! 

# Test Pipeline