In [1]:
# imports
from dataset import SHREC2022Primitives
from networks import MinkowskiFCNN
import transforms as t
from losses import ConeLoss
from cone_fit import cone_fit, Cone
import MinkowskiEngine as ME

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
import numpy as np
from utils import minkowski_collate, create_input_batch

import geomfitty
from geomfitty import geom3d
from geomfitty import plot
from geomfitty import fit3d
# configuration & hyperparameters
path = "/home/ioannis/Desktop/programming/data/SHREC/SHREC2022/dataset"
batch_size = 1
lr=1e-3
num_epochs = 100
eval_step = 1000000000
regression_targets = 7

checkpoint_path = "/home/ioannis/Desktop/programming/phd/SHREC/SHREC2022/checkpoints"
checkpoint_file = os.path.join(checkpoint_path, "cone_100.pth")

In [2]:
# Dataset and dataloader
train_transforms = [t.KeepInitialPoints(),
                    t.Translate(), 
                    t.SphereNormalization(), 
                    t.Initialization(),
                    t.RandomRotate(180, 0),
                    t.RandomRotate(180, 1),
                    t.RandomRotate(180, 2),
                    t.GaussianNoise(),
                    t.GetMean()]

valid_transforms = [t.KeepInitialPoints(),
                    t.Translate(), 
                    t.SphereNormalization(),
                    t.GetMean()]

t_dataset = SHREC2022Primitives(path, 
                                train=True, 
                                valid=False, 
                                valid_split=0.2, 
                                transform=train_transforms,
                                category="cone")

train_loader = DataLoader(t_dataset, 
                          batch_size=batch_size, 
                          shuffle=True, 
                          collate_fn=minkowski_collate, 
                         num_workers=8)

v_dataset = SHREC2022Primitives(path, 
                                train=True, 
                                valid=True, 
                                valid_split=0.2, 
                                transform=valid_transforms,
                                category="cone")

valid_loader = DataLoader(v_dataset, 
                          batch_size=batch_size, 
                          shuffle=False, 
                          collate_fn=minkowski_collate, 
                          num_workers=8)

Specified split already exists. Using the existing one.
Specified split already exists. Using the existing one.


In [3]:
# network and device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

net = MinkowskiFCNN(in_channel = 3, out_channel = regression_targets).to(device)
net.load_state_dict(torch.load(checkpoint_file))
net.eval()
# to use the "transform_cylinder_outputs" method
l = ConeLoss()

Using device: cuda:0


In [4]:
for i, sample in enumerate(valid_loader):
    print(f"Iteration: {i}")
    
    minknet_input = create_input_batch(
            sample, 
            device=device,
            quantization_size=0.05
        )

    gt = sample['labels'][:,1:].to(device)
    pred = net(minknet_input)
    theta, axis, vertex = l.transform_cone_outputs(pred, sample['trans'])
    
    theta = theta.detach().cpu().squeeze().numpy()
    axis   = axis.detach().cpu().squeeze().numpy()
    vertex = vertex.detach().cpu().squeeze().numpy()
    
    points = sample['initial_points'][0].numpy()
    
    #Calculating loss before optimization step
    init_a_loss, init_v_loss, init_t_loss = l(pred, gt, sample['trans'])
    init_loss = init_a_loss + init_v_loss + init_t_loss
    
    #Performing optimization step
    cone = Cone(theta, axis, vertex)
    print(cone.theta)
    cone = cone_fit(points, initial_guess=cone, weights=None)
    try:
        print(cone.theta)
    except:
        break
    theta, axis, vertex = torch.tensor([cone.theta]), torch.tensor(cone.axis), torch.tensor(cone.vertex)
    
    #Calculating loss after optimization step
    cone_params = torch.cat((theta, axis, vertex), dim=0).double().unsqueeze(0).to(device)
    a_loss, v_loss, t_loss = l(cone_params, gt.double(), None)
    final_loss = a_loss + v_loss + t_loss
    
    
    print("LOSS BEFORE OPTIMIZATION: ", init_loss.item())
    print("LOSS AFTER OPTIMIZATION: ", final_loss.item())
    print("IMPROVEMENT: ", init_loss.item() - final_loss.item())
    

Iteration: 0
0.8153984546661377
0.88959256693449
LOSS BEFORE OPTIMIZATION:  0.05725323036313057
LOSS AFTER OPTIMIZATION:  4.507418952413287e-16
IMPROVEMENT:  0.05725323036313012
Iteration: 1
0.7575044631958008
0.6304199307349779
LOSS BEFORE OPTIMIZATION:  0.16795992851257324
LOSS AFTER OPTIMIZATION:  4.3461391537743115e-16
IMPROVEMENT:  0.1679599285125728
Iteration: 2
0.8053709268569946
0.7719867189782142
LOSS BEFORE OPTIMIZATION:  0.6763876080513
LOSS AFTER OPTIMIZATION:  0.153767844650815
IMPROVEMENT:  0.522619763400485
Iteration: 3
0.8345319628715515
0.8633250713257149
LOSS BEFORE OPTIMIZATION:  0.11112755537033081
LOSS AFTER OPTIMIZATION:  1.6707717051971716e-16
IMPROVEMENT:  0.11112755537033064
Iteration: 4
0.6834679245948792
0.6952131085971137
LOSS BEFORE OPTIMIZATION:  0.1807001382112503
LOSS AFTER OPTIMIZATION:  6.901558587439122e-15
IMPROVEMENT:  0.1807001382112434
Iteration: 5
0.696857213973999
0.6452619470344679
LOSS BEFORE OPTIMIZATION:  0.12755554914474487
LOSS AFTER OPTIM

0.6354731741892503
LOSS BEFORE OPTIMIZATION:  0.20997311174869537
LOSS AFTER OPTIMIZATION:  0.0008636690297926723
IMPROVEMENT:  0.2091094427189027
Iteration: 47
0.7778702974319458
0.9500170213874621
LOSS BEFORE OPTIMIZATION:  0.037407297641038895
LOSS AFTER OPTIMIZATION:  0.05658526706090922
IMPROVEMENT:  -0.019177969419870322
Iteration: 48
0.7361711263656616
0.6324924750401728
LOSS BEFORE OPTIMIZATION:  0.43823081254959106
LOSS AFTER OPTIMIZATION:  0.03200951745069929
IMPROVEMENT:  0.40622129509889177
Iteration: 49
0.6954922676086426
1.0436952263420094
LOSS BEFORE OPTIMIZATION:  1.279305100440979
LOSS AFTER OPTIMIZATION:  2.6607304594302095e-14
IMPROVEMENT:  1.2793051004409524
Iteration: 50
0.7513407468795776
0.8772688420412519
LOSS BEFORE OPTIMIZATION:  0.07941852509975433
LOSS AFTER OPTIMIZATION:  0.0005323304329306689
IMPROVEMENT:  0.07888619466682366
Iteration: 51
0.7457627654075623
0.8230482408639728
LOSS BEFORE OPTIMIZATION:  0.12903137505054474
LOSS AFTER OPTIMIZATION:  4.41977

0.5495328191596421
LOSS BEFORE OPTIMIZATION:  0.4002663493156433
LOSS AFTER OPTIMIZATION:  0.00051778107353635
IMPROVEMENT:  0.39974856824210697
Iteration: 96
0.7802945375442505
1.000535751762639
LOSS BEFORE OPTIMIZATION:  1.4595601558685303
LOSS AFTER OPTIMIZATION:  0.027372793714958578
IMPROVEMENT:  1.4321873621535717
Iteration: 97
0.8393027782440186
1.005292450515654
LOSS BEFORE OPTIMIZATION:  0.17613251507282257
LOSS AFTER OPTIMIZATION:  0.02075909131016258
IMPROVEMENT:  0.15537342376266
Iteration: 98
0.6794394254684448
0.6437462252909657
LOSS BEFORE OPTIMIZATION:  0.06585343927145004
LOSS AFTER OPTIMIZATION:  0.005918368734045871
IMPROVEMENT:  0.05993507053740417
Iteration: 99
0.7752371430397034
0.6619409656006459
LOSS BEFORE OPTIMIZATION:  2.034926414489746
LOSS AFTER OPTIMIZATION:  5.229636786644623e-14
IMPROVEMENT:  2.0349264144896937
Iteration: 100
0.7043266296386719
0.8790855409638518
LOSS BEFORE OPTIMIZATION:  3.303422212600708
LOSS AFTER OPTIMIZATION:  3.888335044526703e-14

0.9782450088655309
LOSS BEFORE OPTIMIZATION:  0.29714536666870117
LOSS AFTER OPTIMIZATION:  0.00034253721832433346
IMPROVEMENT:  0.29680282945037684
Iteration: 146
0.603639543056488
0.5736946727587996
LOSS BEFORE OPTIMIZATION:  0.06052420660853386
LOSS AFTER OPTIMIZATION:  0.001384310548799614
IMPROVEMENT:  0.05913989605973424
Iteration: 147
0.7403577566146851
0.7760111142724011
LOSS BEFORE OPTIMIZATION:  0.2640165388584137
LOSS AFTER OPTIMIZATION:  0.004401539457467825
IMPROVEMENT:  0.25961499940094585
Iteration: 148
0.815762460231781
0.8757575264915133
LOSS BEFORE OPTIMIZATION:  0.0085153104737401
LOSS AFTER OPTIMIZATION:  4.983107112716357e-05
IMPROVEMENT:  0.008465479402612937
Iteration: 149
0.8002519607543945
0.874599142382193
LOSS BEFORE OPTIMIZATION:  0.40641212463378906
LOSS AFTER OPTIMIZATION:  0.01374778221862674
IMPROVEMENT:  0.39266434241516235
Iteration: 150
0.7879856824874878
1.0471975702289258
LOSS BEFORE OPTIMIZATION:  1.4817827939987183
LOSS AFTER OPTIMIZATION:  5.1429

0.5288522764419997
LOSS BEFORE OPTIMIZATION:  0.022110572084784508
LOSS AFTER OPTIMIZATION:  3.9449296211593156e-15
IMPROVEMENT:  0.022110572084780563
Iteration: 193
0.7985063791275024
0.8625230684611574
LOSS BEFORE OPTIMIZATION:  0.5389360785484314
LOSS AFTER OPTIMIZATION:  0.005961435036023154
IMPROVEMENT:  0.5329746435124082
Iteration: 194
0.7053025364875793
0.5681626070834541
LOSS BEFORE OPTIMIZATION:  0.08595572412014008
LOSS AFTER OPTIMIZATION:  0.010286657811432476
IMPROVEMENT:  0.0756690663087076
Iteration: 195
0.918201208114624
0.9334074922693476
LOSS BEFORE OPTIMIZATION:  0.03132614493370056
LOSS AFTER OPTIMIZATION:  0.0007977222242627322
IMPROVEMENT:  0.03052842270943783
Iteration: 196
0.7082328796386719
0.7569725668899763
LOSS BEFORE OPTIMIZATION:  0.15119387209415436
LOSS AFTER OPTIMIZATION:  0.0002456436055850173
IMPROVEMENT:  0.15094822848856934
Iteration: 197
0.7820315361022949
0.7067507681060311
LOSS BEFORE OPTIMIZATION:  0.021654276177287102
LOSS AFTER OPTIMIZATION:  

ValueError: PositiveNumber must be initialized with a positive number

In [None]:
cone

In [7]:
from utils import visualize_pointcloud
visualize_pointcloud(sample["initial_points"][0])