In [1]:
#imports
import torch
from torch.utils.data import DataLoader
import MinkowskiEngine as ME
from tqdm import tqdm
import os
import time

import transforms as t
from dataset import SHREC2022Primitives
from networks import MinkowskiFCNN
from utils import minkowski_collate, create_input_batch 
import losses

# Paths for the dataset and the checkpoint directory.
# Override them per machine with the SHREC_DATA_PATH / SHREC_CHECKPOINT_PATH
# environment variables instead of editing or commenting out hardcoded paths;
# the defaults below are used when the variables are not set.
data_path = os.environ.get(
    "SHREC_DATA_PATH",
    "/home/ioannis/Desktop/programming/data/SHREC/SHREC2022/dataset",
)
checkpoint_path = os.environ.get(
    "SHREC_CHECKPOINT_PATH",
    "/home/ioannis/Desktop/programming/phd/SHREC/SHREC2022/checkpoints",
)

# Checkpoint files, all located inside `checkpoint_path`.
cls_checkpoint = os.path.join(checkpoint_path, "classification.pth")
plane_checkpoint = os.path.join(checkpoint_path, "plane.pth")
# center-only sphere model (the radius is not regressed by the network)
sphere_checkpoint = os.path.join(checkpoint_path, "sphere_no_radius.pth")
cylinder_checkpoint = os.path.join(checkpoint_path, "cylinder.pth")
cone_checkpoint = os.path.join(checkpoint_path, "cone.pth")
torus_checkpoint = os.path.join(checkpoint_path, "torus.pth")

# ---- general parameters ----
valid_split = 0.15  # presumably the held-out fraction used for the train/validation split


# ---- point-cloud transforms ----
# One augmentation pipeline is shared by every network and task.
# Training: translate + sphere-normalise, initialise, then apply three
# RandomRotate(180, axis) passes for axes 0, 1 and 2, add Gaussian noise,
# and finish with GetMean.
train_transforms = [
    t.KeepInitialPoints(),
    t.Translate(),
    t.SphereNormalization(),
    t.Initialization(),
    *(t.RandomRotate(180, axis) for axis in (0, 1, 2)),
    t.GaussianNoise(),
    t.GetMean(),
]

# Validation: the same normalisation steps, without any random augmentation.
valid_transforms = [
    t.KeepInitialPoints(),
    t.Translate(),
    t.SphereNormalization(),
    t.GetMean(),
]


# ---- compute devices ----
# Prefer CUDA when a GPU is visible; `cpu_device` is kept as an explicit CPU handle.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
cpu_device = torch.device("cpu")
print(f"Using Device: {device}")

Using Device: cuda


In [2]:
batch_size = 128

# ---- model ----
# Center-only sphere regressor: 3 input channels per point and 3 outputs
# (the sphere center); the radius is not predicted by the network.
sphere_net = MinkowskiFCNN(in_channel=3, out_channel=3).to(device)

# ---- optimisation ----
# Adam; StepLR halves the learning rate every 15 scheduler steps
# (the training loop steps it once per epoch).
optimizer = torch.optim.Adam(sphere_net.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.5)

sphere_loss = losses.SphereLoss()

# ---- data ----
# Both subsets come from the training folder, restricted to the "sphere"
# category and using the same split fraction.
_sphere_split_kwargs = dict(train=True, valid_split=valid_split, category="sphere")
_loader_kwargs = dict(collate_fn=minkowski_collate, num_workers=8)

# The training dataset is built first: it is the one that creates the split
# file that the validation dataset then reuses.
t_dataset = SHREC2022Primitives(
    data_path,
    valid=False,
    transform=train_transforms,
    **_sphere_split_kwargs,
)
train_loader = DataLoader(t_dataset, batch_size=batch_size, shuffle=True, **_loader_kwargs)

v_dataset = SHREC2022Primitives(
    data_path,
    valid=True,
    transform=valid_transforms,
    **_sphere_split_kwargs,
)
# batch_size=1 is required: the validation loop estimates the radius from
# the points of a single shape.
valid_loader = DataLoader(v_dataset, batch_size=1, shuffle=False, **_loader_kwargs)

Dataset path:  /home/ioannis/Desktop/programming/data/SHREC/SHREC2022/dataset/training
Creating a new train-validation split.
Dataset path:  /home/ioannis/Desktop/programming/data/SHREC/SHREC2022/dataset/training
Specified split already exists. Using the existing one.


In [3]:
# Training / validation loop for the center-only sphere regressor.
#
# Training: the network outputs 3 values (the center). A dummy radius of 1 is
# prepended so the loss receives [radius, center], and only the center loss
# is back-propagated.
#
# Validation: the radius is estimated geometrically as the mean distance from
# the input points to the predicted center, and both the center and radius
# losses are reported. That estimate pools every point in the batch, so
# valid_loader must yield one shape per batch (batch_size=1).
num_epochs = 100
eval_step = 1              # validate every `eval_step` epochs
quantization_size = 0.05   # voxel size passed to create_input_batch

start = time.time()

for epoch in range(num_epochs):
    # ---------------- training ----------------
    m_loss = 0
    m_center_loss = 0

    for batch in tqdm(train_loader):
        optimizer.zero_grad()

        # labels[:, 1:] are the regression targets fed to the loss
        # (NOTE(review): column 0 is presumably a class/primitive id — confirm)
        labels = batch["labels"].to(device)
        gt = labels[:, 1:]
        minknet_input = create_input_batch(
            batch,
            device=device,
            quantization_size=quantization_size,
        )

        pred = sphere_net(minknet_input)
        # prepend a dummy radius of 1 so the loss sees [radius, center]
        dummy_radius = torch.ones(pred.shape[0], 1, device=pred.device, dtype=pred.dtype)
        pred = torch.cat([dummy_radius, pred], dim=-1)

        # only the center term is optimised; the radius term is meaningless here
        c_loss, _ = sphere_loss(pred, gt, batch["trans"])
        c_loss = c_loss.mean(0)
        loss = c_loss

        loss.backward()
        optimizer.step()

        m_center_loss += c_loss.item()
        m_loss += loss.item()

    scheduler.step()

    m_loss /= len(train_loader)
    m_center_loss /= len(train_loader)

    print(f" Epoch: {epoch} | Training: loss = {m_loss}")
    print(f" Center loss: {m_center_loss}")

    # ---------------- validation ----------------
    if (epoch + 1) % eval_step == 0:
        sphere_net.eval()
        try:
            with torch.no_grad():
                m_loss = 0
                m_center_loss = 0
                m_radius_loss = 0

                for batch in tqdm(valid_loader):
                    # input point features, used to estimate the radius
                    points = batch["features"].to(device)

                    labels = batch["labels"].to(device)
                    gt = labels[:, 1:]
                    minknet_input = create_input_batch(
                        batch,
                        device=device,
                        quantization_size=quantization_size,
                    )

                    pred = sphere_net(minknet_input)

                    # The radius estimate pools every point in the batch, so
                    # it is only valid for a single shape per batch.
                    if pred.shape[0] != 1:
                        raise ValueError(
                            "Validation radius estimation expects one shape per batch "
                            f"(valid_loader batch_size=1), got {pred.shape[0]}."
                        )

                    # radius = mean Euclidean distance from the predicted center to the points
                    radius = (points - pred).norm(p=2, dim=-1).mean(0)
                    pred = torch.cat([radius.reshape(1, 1), pred], dim=-1)

                    c_loss, r_loss = sphere_loss(pred, gt, batch["trans"])
                    c_loss = c_loss.mean(0)
                    r_loss = r_loss.mean(0)
                    loss = r_loss + c_loss

                    m_center_loss += c_loss.item()
                    m_radius_loss += r_loss.item()
                    m_loss += loss.item()

                m_loss /= len(valid_loader)
                m_center_loss /= len(valid_loader)
                m_radius_loss /= len(valid_loader)

                print(f" ------------ | Validation: loss = {m_loss}")
                print(f" Center loss: {m_center_loss} | Radius loss: {m_radius_loss}")
        finally:
            # restore training mode even if validation fails, so the kernel
            # is not left with the network in eval mode
            sphere_net.train()

finish = time.time()
print(f"Training finished in {finish - start:.1f} s")

100%|██████████| 62/62 [00:20<00:00,  2.97it/s]


 Epoch: 0 | Training: loss = 4.344136926435655
 Center loss: 4.344136926435655


100%|██████████| 1380/1380 [00:16<00:00, 82.92it/s]


 ------------ | Validation: loss = 3.708247376030203
 Center loss: 2.8216318811050387 | Radius loss: 0.8866154892951392


100%|██████████| 62/62 [00:20<00:00,  3.08it/s]


 Epoch: 1 | Training: loss = 1.6976697666029776
 Center loss: 1.6976697666029776


100%|██████████| 1380/1380 [00:16<00:00, 81.27it/s]


 ------------ | Validation: loss = 1.4876167120180848
 Center loss: 1.0129245204997817 | Radius loss: 0.47469218981921907


100%|██████████| 62/62 [00:19<00:00,  3.11it/s]


 Epoch: 2 | Training: loss = 0.9177618853507503
 Center loss: 0.9177618853507503


100%|██████████| 1380/1380 [00:16<00:00, 81.55it/s]


 ------------ | Validation: loss = 0.8656669523793044
 Center loss: 0.5803250792836163 | Radius loss: 0.285341872596181


100%|██████████| 62/62 [00:20<00:00,  3.04it/s]


 Epoch: 3 | Training: loss = 0.6769296892227665
 Center loss: 0.6769296892227665


100%|██████████| 1380/1380 [00:16<00:00, 82.09it/s]


 ------------ | Validation: loss = 0.6489503210298279
 Center loss: 0.41662271931414 | Radius loss: 0.23232760252341367


100%|██████████| 62/62 [00:20<00:00,  3.05it/s]


 Epoch: 4 | Training: loss = 0.5781110092516868
 Center loss: 0.5781110092516868


100%|██████████| 1380/1380 [00:16<00:00, 82.36it/s]


 ------------ | Validation: loss = 0.6280028944687439
 Center loss: 0.4269693647360512 | Radius loss: 0.20103353105490937


100%|██████████| 62/62 [00:20<00:00,  3.08it/s]


 Epoch: 5 | Training: loss = 0.527587445032212
 Center loss: 0.527587445032212


100%|██████████| 1380/1380 [00:16<00:00, 82.44it/s]


 ------------ | Validation: loss = 0.46300049423194994
 Center loss: 0.29729395146176735 | Radius loss: 0.165706543138753


100%|██████████| 62/62 [00:20<00:00,  3.06it/s]


 Epoch: 6 | Training: loss = 0.4683287984901859
 Center loss: 0.4683287984901859


100%|██████████| 1380/1380 [00:16<00:00, 82.93it/s]


 ------------ | Validation: loss = 0.37122424631043166
 Center loss: 0.2447197255658967 | Radius loss: 0.1265045212136322


100%|██████████| 62/62 [00:20<00:00,  3.09it/s]


 Epoch: 7 | Training: loss = 0.45651137348144283
 Center loss: 0.45651137348144283


100%|██████████| 1380/1380 [00:16<00:00, 81.39it/s]


 ------------ | Validation: loss = 0.5427719432931475
 Center loss: 0.3814562640442223 | Radius loss: 0.1613156810553449


100%|██████████| 62/62 [00:20<00:00,  3.04it/s]


 Epoch: 8 | Training: loss = 0.4108755944236632
 Center loss: 0.4108755944236632


100%|██████████| 1380/1380 [00:16<00:00, 81.41it/s]


 ------------ | Validation: loss = 0.3331965047200936
 Center loss: 0.22108755589655155 | Radius loss: 0.11210894796941526


100%|██████████| 62/62 [00:20<00:00,  3.07it/s]


 Epoch: 9 | Training: loss = 0.3654829720335622
 Center loss: 0.3654829720335622


100%|██████████| 1380/1380 [00:17<00:00, 80.64it/s]


 ------------ | Validation: loss = 0.3796958041161934
 Center loss: 0.2542892497487259 | Radius loss: 0.12540655416331722


100%|██████████| 62/62 [00:20<00:00,  3.08it/s]


 Epoch: 10 | Training: loss = 0.4039816498275726
 Center loss: 0.4039816498275726


100%|██████████| 1380/1380 [00:16<00:00, 81.57it/s]


 ------------ | Validation: loss = 0.3231291612116721
 Center loss: 0.22074040094370687 | Radius loss: 0.10238876075314586


100%|██████████| 62/62 [00:20<00:00,  3.03it/s]


 Epoch: 11 | Training: loss = 0.3892193217912028
 Center loss: 0.3892193217912028


100%|██████████| 1380/1380 [00:16<00:00, 81.78it/s]


 ------------ | Validation: loss = 0.44550663264783813
 Center loss: 0.327149935173878 | Radius loss: 0.11835669743400268


100%|██████████| 62/62 [00:20<00:00,  3.05it/s]


 Epoch: 12 | Training: loss = 0.3420115824668638
 Center loss: 0.3420115824668638


100%|██████████| 1380/1380 [00:16<00:00, 81.93it/s]


 ------------ | Validation: loss = 0.3261685825015656
 Center loss: 0.21643177553668846 | Radius loss: 0.10973680742786243


100%|██████████| 62/62 [00:19<00:00,  3.12it/s]


 Epoch: 13 | Training: loss = 0.31160681814916674
 Center loss: 0.31160681814916674


100%|██████████| 1380/1380 [00:16<00:00, 82.04it/s]


 ------------ | Validation: loss = 0.36680217998476017
 Center loss: 0.25259227230710457 | Radius loss: 0.1142099071766364


100%|██████████| 62/62 [00:20<00:00,  2.96it/s]


 Epoch: 14 | Training: loss = 0.3548576154055134
 Center loss: 0.3548576154055134


100%|██████████| 1380/1380 [00:17<00:00, 77.43it/s]


 ------------ | Validation: loss = 0.3016896654477392
 Center loss: 0.19321500268777764 | Radius loss: 0.10847466083866987


100%|██████████| 62/62 [00:20<00:00,  3.01it/s]


 Epoch: 15 | Training: loss = 0.2811548467124662
 Center loss: 0.2811548467124662


100%|██████████| 1380/1380 [00:16<00:00, 81.81it/s]


 ------------ | Validation: loss = 0.2469463705646205
 Center loss: 0.16454969608032802 | Radius loss: 0.08239667332337089


100%|██████████| 62/62 [00:20<00:00,  3.04it/s]


 Epoch: 16 | Training: loss = 0.25824080647960784
 Center loss: 0.25824080647960784


100%|██████████| 1380/1380 [00:17<00:00, 79.95it/s]


 ------------ | Validation: loss = 0.21377164919283223
 Center loss: 0.14151492008785008 | Radius loss: 0.07225672971872542


100%|██████████| 62/62 [00:21<00:00,  2.93it/s]


 Epoch: 17 | Training: loss = 0.25938890850351703
 Center loss: 0.25938890850351703


100%|██████████| 1380/1380 [00:17<00:00, 79.87it/s]


 ------------ | Validation: loss = 0.21698524227749263
 Center loss: 0.14358483391410698 | Radius loss: 0.07340040826360017


100%|██████████| 62/62 [00:21<00:00,  2.86it/s]


 Epoch: 18 | Training: loss = 0.2921620208409525
 Center loss: 0.2921620208409525


100%|██████████| 1380/1380 [00:16<00:00, 81.47it/s]


 ------------ | Validation: loss = 0.21217898891361825
 Center loss: 0.13908062828700496 | Radius loss: 0.07309836189931018


100%|██████████| 62/62 [00:20<00:00,  3.08it/s]


 Epoch: 19 | Training: loss = 0.2462226402375006
 Center loss: 0.2462226402375006


100%|██████████| 1380/1380 [00:17<00:00, 80.84it/s]


 ------------ | Validation: loss = 0.21579082637302538
 Center loss: 0.14330067261817284 | Radius loss: 0.07249015319760309


100%|██████████| 62/62 [00:20<00:00,  2.96it/s]


 Epoch: 20 | Training: loss = 0.2331562513305295
 Center loss: 0.2331562513305295


100%|██████████| 1380/1380 [00:18<00:00, 74.33it/s]


 ------------ | Validation: loss = 0.18226675093243577
 Center loss: 0.12153107086053132 | Radius loss: 0.06073567969919677


100%|██████████| 62/62 [00:21<00:00,  2.85it/s]


 Epoch: 21 | Training: loss = 0.23754304479206761
 Center loss: 0.23754304479206761


100%|██████████| 1380/1380 [00:16<00:00, 81.62it/s]


 ------------ | Validation: loss = 0.22690292881818386
 Center loss: 0.14858248146378752 | Radius loss: 0.0783204480324783


100%|██████████| 62/62 [00:20<00:00,  3.04it/s]


 Epoch: 22 | Training: loss = 0.24257434616165777
 Center loss: 0.24257434616165777


100%|██████████| 1380/1380 [00:16<00:00, 82.53it/s]


 ------------ | Validation: loss = 0.19128548455664185
 Center loss: 0.1264903119188945 | Radius loss: 0.06479517160962316


100%|██████████| 62/62 [00:20<00:00,  3.08it/s]


 Epoch: 23 | Training: loss = 0.2313363580934463
 Center loss: 0.2313363580934463


100%|██████████| 1380/1380 [00:16<00:00, 82.41it/s]


 ------------ | Validation: loss = 0.19307430887511937
 Center loss: 0.13294474335681128 | Radius loss: 0.0601295640754155


100%|██████████| 62/62 [00:20<00:00,  2.97it/s]


 Epoch: 24 | Training: loss = 0.24103133188140008
 Center loss: 0.24103133188140008


100%|██████████| 1380/1380 [00:16<00:00, 81.74it/s]


 ------------ | Validation: loss = 0.22422706403773138
 Center loss: 0.15008895741019565 | Radius loss: 0.07413810675133353


100%|██████████| 62/62 [00:20<00:00,  3.05it/s]


 Epoch: 25 | Training: loss = 0.24698774132036394
 Center loss: 0.24698774132036394


100%|██████████| 1380/1380 [00:16<00:00, 81.32it/s]


 ------------ | Validation: loss = 0.19613405777580198
 Center loss: 0.12864248928104724 | Radius loss: 0.06749156863518539


100%|██████████| 62/62 [00:20<00:00,  3.01it/s]


 Epoch: 26 | Training: loss = 0.24301684599730275
 Center loss: 0.24301684599730275


100%|██████████| 1380/1380 [00:17<00:00, 80.95it/s]


 ------------ | Validation: loss = 0.1856737592654679
 Center loss: 0.12451527643783977 | Radius loss: 0.06115848261768287


100%|██████████| 62/62 [00:20<00:00,  3.07it/s]


 Epoch: 27 | Training: loss = 0.22822629444060788
 Center loss: 0.22822629444060788


100%|██████████| 1380/1380 [00:16<00:00, 81.81it/s]


 ------------ | Validation: loss = 0.172117325444729
 Center loss: 0.11096439871467158 | Radius loss: 0.06115292718302994


100%|██████████| 62/62 [00:20<00:00,  3.08it/s]


 Epoch: 28 | Training: loss = 0.20402083329616055
 Center loss: 0.20402083329616055


100%|██████████| 1380/1380 [00:17<00:00, 80.65it/s]


 ------------ | Validation: loss = 0.21837819539910214
 Center loss: 0.14339105622980836 | Radius loss: 0.07498713952152102


100%|██████████| 62/62 [00:20<00:00,  3.00it/s]


 Epoch: 29 | Training: loss = 0.21699678633482225
 Center loss: 0.21699678633482225


100%|██████████| 1380/1380 [00:16<00:00, 81.44it/s]


 ------------ | Validation: loss = 0.18731117697180458
 Center loss: 0.12654021531091703 | Radius loss: 0.06077096075848728


100%|██████████| 62/62 [00:20<00:00,  3.08it/s]


 Epoch: 30 | Training: loss = 0.1808675016847349
 Center loss: 0.1808675016847349


100%|██████████| 1380/1380 [00:17<00:00, 80.95it/s]


 ------------ | Validation: loss = 0.16111950914004727
 Center loss: 0.10542221099059344 | Radius loss: 0.05569729875668757


100%|██████████| 62/62 [00:20<00:00,  3.05it/s]


 Epoch: 31 | Training: loss = 0.1794442022519727
 Center loss: 0.1794442022519727


100%|██████████| 1380/1380 [00:17<00:00, 78.90it/s]


 ------------ | Validation: loss = 0.14283750575177065
 Center loss: 0.09187045906966748 | Radius loss: 0.05096704673781583


100%|██████████| 62/62 [00:20<00:00,  3.01it/s]


 Epoch: 32 | Training: loss = 0.17439143167388055
 Center loss: 0.17439143167388055


100%|██████████| 1380/1380 [00:17<00:00, 79.27it/s]


 ------------ | Validation: loss = 0.16062635123009808
 Center loss: 0.10816668931644097 | Radius loss: 0.052459661578625374


100%|██████████| 62/62 [00:20<00:00,  3.00it/s]


 Epoch: 33 | Training: loss = 0.18861606897365663
 Center loss: 0.18861606897365663


100%|██████████| 1380/1380 [00:17<00:00, 81.09it/s]


 ------------ | Validation: loss = 0.15892784573897833
 Center loss: 0.10644821821942047 | Radius loss: 0.052479627329884065


100%|██████████| 62/62 [00:20<00:00,  3.05it/s]


 Epoch: 34 | Training: loss = 0.18164184725572985
 Center loss: 0.18164184725572985


100%|██████████| 1380/1380 [00:17<00:00, 81.03it/s]


 ------------ | Validation: loss = 0.1625674916323012
 Center loss: 0.10845498734859153 | Radius loss: 0.054112503815786694


100%|██████████| 62/62 [00:20<00:00,  3.03it/s]


 Epoch: 35 | Training: loss = 0.1985888788777013
 Center loss: 0.1985888788777013


100%|██████████| 1380/1380 [00:16<00:00, 81.57it/s]


 ------------ | Validation: loss = 0.14889124278245394
 Center loss: 0.09749412418182211 | Radius loss: 0.05139711818963051


100%|██████████| 62/62 [00:20<00:00,  3.00it/s]


 Epoch: 36 | Training: loss = 0.18193903240946033
 Center loss: 0.18193903240946033


100%|██████████| 1380/1380 [00:16<00:00, 81.79it/s]


 ------------ | Validation: loss = 0.14186735475209983
 Center loss: 0.09516075238986388 | Radius loss: 0.04670660260793678


100%|██████████| 62/62 [00:20<00:00,  3.02it/s]


 Epoch: 37 | Training: loss = 0.1823417520330798
 Center loss: 0.1823417520330798


100%|██████████| 1380/1380 [00:16<00:00, 81.38it/s]


 ------------ | Validation: loss = 0.1534425443222405
 Center loss: 0.1013178328406866 | Radius loss: 0.052124712472120575


100%|██████████| 62/62 [00:19<00:00,  3.11it/s]


 Epoch: 38 | Training: loss = 0.18221305250640837
 Center loss: 0.18221305250640837


100%|██████████| 1380/1380 [00:16<00:00, 81.22it/s]


 ------------ | Validation: loss = 0.14485666575197104
 Center loss: 0.09311249391625018 | Radius loss: 0.05174417276609338


100%|██████████| 62/62 [00:20<00:00,  3.05it/s]


 Epoch: 39 | Training: loss = 0.19248083738550062
 Center loss: 0.19248083738550062


100%|██████████| 1380/1380 [00:16<00:00, 82.20it/s]


 ------------ | Validation: loss = 0.15195964509192086
 Center loss: 0.09827767510352081 | Radius loss: 0.053681970028899545


100%|██████████| 62/62 [00:20<00:00,  3.09it/s]


 Epoch: 40 | Training: loss = 0.18796471133828163
 Center loss: 0.18796471133828163


100%|██████████| 1380/1380 [00:16<00:00, 81.82it/s]


 ------------ | Validation: loss = 0.18220587856446493
 Center loss: 0.11999351871934882 | Radius loss: 0.06221235907719837


100%|██████████| 62/62 [00:20<00:00,  3.05it/s]


 Epoch: 41 | Training: loss = 0.16402867376323668
 Center loss: 0.16402867376323668


100%|██████████| 1380/1380 [00:16<00:00, 81.62it/s]


 ------------ | Validation: loss = 0.14549873058644733
 Center loss: 0.09453690989349801 | Radius loss: 0.050961820388557434


100%|██████████| 62/62 [00:20<00:00,  3.05it/s]


 Epoch: 42 | Training: loss = 0.17318607818695805
 Center loss: 0.17318607818695805


100%|██████████| 1380/1380 [00:17<00:00, 79.58it/s]


 ------------ | Validation: loss = 0.1403747825567932
 Center loss: 0.09244990221172422 | Radius loss: 0.047924879848904554


100%|██████████| 62/62 [00:20<00:00,  2.99it/s]


 Epoch: 43 | Training: loss = 0.17244395277192515
 Center loss: 0.17244395277192515


100%|██████████| 1380/1380 [00:16<00:00, 82.04it/s]


 ------------ | Validation: loss = 0.1645916599290512
 Center loss: 0.11606192438396291 | Radius loss: 0.048529735985262994


100%|██████████| 62/62 [00:20<00:00,  3.09it/s]


 Epoch: 44 | Training: loss = 0.21486450563515386
 Center loss: 0.21486450563515386


100%|██████████| 1380/1380 [00:16<00:00, 82.04it/s]


 ------------ | Validation: loss = 0.14318480401315078
 Center loss: 0.09333636323147178 | Radius loss: 0.04984844050099172


100%|██████████| 62/62 [00:20<00:00,  3.10it/s]


 Epoch: 45 | Training: loss = 0.2074757220764314
 Center loss: 0.2074757220764314


100%|██████████| 1380/1380 [00:16<00:00, 82.60it/s]


 ------------ | Validation: loss = 0.1476108130913707
 Center loss: 0.0947046434841468 | Radius loss: 0.052906169123115694


100%|██████████| 62/62 [00:19<00:00,  3.12it/s]


 Epoch: 46 | Training: loss = 0.16439894779074576
 Center loss: 0.16439894779074576


100%|██████████| 1380/1380 [00:16<00:00, 82.28it/s]


 ------------ | Validation: loss = 0.1418289687320741
 Center loss: 0.09118924431552486 | Radius loss: 0.050639725290416535


100%|██████████| 62/62 [00:19<00:00,  3.13it/s]


 Epoch: 47 | Training: loss = 0.17771453662745415
 Center loss: 0.17771453662745415


100%|██████████| 1380/1380 [00:16<00:00, 81.70it/s]


 ------------ | Validation: loss = 0.1356158153663484
 Center loss: 0.08891021118473398 | Radius loss: 0.04670560512122811


100%|██████████| 62/62 [00:19<00:00,  3.14it/s]


 Epoch: 48 | Training: loss = 0.15941198025980302
 Center loss: 0.15941198025980302


100%|██████████| 1380/1380 [00:16<00:00, 82.68it/s]


 ------------ | Validation: loss = 0.12948458428796006
 Center loss: 0.08258995994033615 | Radius loss: 0.04689462361074365


100%|██████████| 62/62 [00:20<00:00,  3.06it/s]


 Epoch: 49 | Training: loss = 0.16474035358236683
 Center loss: 0.16474035358236683


100%|██████████| 1380/1380 [00:16<00:00, 81.72it/s]


 ------------ | Validation: loss = 0.12333416168757654
 Center loss: 0.07949149486899525 | Radius loss: 0.043842667130652505


100%|██████████| 62/62 [00:20<00:00,  3.05it/s]


 Epoch: 50 | Training: loss = 0.16021561730773218
 Center loss: 0.16021561730773218


100%|██████████| 1380/1380 [00:16<00:00, 81.21it/s]


 ------------ | Validation: loss = 0.1262810230088439
 Center loss: 0.0805046912474991 | Radius loss: 0.04577633105532801


100%|██████████| 62/62 [00:20<00:00,  3.01it/s]


 Epoch: 51 | Training: loss = 0.16078934674301454
 Center loss: 0.16078934674301454


100%|██████████| 1380/1380 [00:16<00:00, 82.40it/s]


 ------------ | Validation: loss = 0.1258843200876072
 Center loss: 0.08313902245376137 | Radius loss: 0.04274529751388773


100%|██████████| 62/62 [00:21<00:00,  2.94it/s]


 Epoch: 52 | Training: loss = 0.1686810603545558
 Center loss: 0.1686810603545558


100%|██████████| 1380/1380 [00:17<00:00, 80.15it/s]


 ------------ | Validation: loss = 0.13646109108035423
 Center loss: 0.08637401596430552 | Radius loss: 0.050087075849817364


100%|██████████| 62/62 [00:19<00:00,  3.14it/s]


 Epoch: 53 | Training: loss = 0.16639302274392498
 Center loss: 0.16639302274392498


100%|██████████| 1380/1380 [00:17<00:00, 80.69it/s]


 ------------ | Validation: loss = 0.12157994508547151
 Center loss: 0.07898831123200817 | Radius loss: 0.042591632758337106


100%|██████████| 62/62 [00:20<00:00,  3.05it/s]


 Epoch: 54 | Training: loss = 0.1622261330485344
 Center loss: 0.1622261330485344


100%|██████████| 1380/1380 [00:16<00:00, 82.69it/s]


 ------------ | Validation: loss = 0.12966751694609452
 Center loss: 0.08430013695406556 | Radius loss: 0.04536738094986886


100%|██████████| 62/62 [00:20<00:00,  3.00it/s]


 Epoch: 55 | Training: loss = 0.15110139368522552
 Center loss: 0.15110139368522552


100%|██████████| 1380/1380 [00:17<00:00, 78.52it/s]


 ------------ | Validation: loss = 0.106638234338496
 Center loss: 0.06862885581997588 | Radius loss: 0.038009377397826366


100%|██████████| 62/62 [00:20<00:00,  2.99it/s]


 Epoch: 56 | Training: loss = 0.14318539238264483
 Center loss: 0.14318539238264483


100%|██████████| 1380/1380 [00:17<00:00, 81.07it/s]


 ------------ | Validation: loss = 0.11260048979705628
 Center loss: 0.07377256074273053 | Radius loss: 0.038827930164599275


100%|██████████| 62/62 [00:20<00:00,  3.02it/s]


 Epoch: 57 | Training: loss = 0.1608919792117611
 Center loss: 0.1608919792117611


100%|██████████| 1380/1380 [00:17<00:00, 81.16it/s]


 ------------ | Validation: loss = 0.11203847588997408
 Center loss: 0.07142062524489239 | Radius loss: 0.040617850630561565


100%|██████████| 62/62 [00:20<00:00,  3.06it/s]


 Epoch: 58 | Training: loss = 0.15311066301599627
 Center loss: 0.15311066301599627


100%|██████████| 1380/1380 [00:16<00:00, 81.20it/s]


 ------------ | Validation: loss = 0.11951522609337902
 Center loss: 0.0764533610024217 | Radius loss: 0.04306186524619595


100%|██████████| 62/62 [00:20<00:00,  3.06it/s]


 Epoch: 59 | Training: loss = 0.1511319480355709
 Center loss: 0.1511319480355709


100%|██████████| 1380/1380 [00:16<00:00, 82.22it/s]


 ------------ | Validation: loss = 0.1262023480752575
 Center loss: 0.08161316445003267 | Radius loss: 0.04458918283769335


100%|██████████| 62/62 [00:19<00:00,  3.13it/s]


 Epoch: 60 | Training: loss = 0.16190068940481833
 Center loss: 0.16190068940481833


100%|██████████| 1380/1380 [00:16<00:00, 81.47it/s]


 ------------ | Validation: loss = 0.11214861177091945
 Center loss: 0.07254690912090603 | Radius loss: 0.03960170273444251


100%|██████████| 62/62 [00:20<00:00,  3.06it/s]


 Epoch: 61 | Training: loss = 0.16055025424688094
 Center loss: 0.16055025424688094


100%|██████████| 1380/1380 [00:16<00:00, 82.55it/s]


 ------------ | Validation: loss = 0.1139673811816254
 Center loss: 0.07533137227839451 | Radius loss: 0.038636007722017106


100%|██████████| 62/62 [00:20<00:00,  3.07it/s]


 Epoch: 62 | Training: loss = 0.1588386693789113
 Center loss: 0.1588386693789113


100%|██████████| 1380/1380 [00:17<00:00, 81.06it/s]


 ------------ | Validation: loss = 0.11939593119944523
 Center loss: 0.08029579967038919 | Radius loss: 0.03910013191424041


100%|██████████| 62/62 [00:21<00:00,  2.83it/s]


 Epoch: 63 | Training: loss = 0.1484850840943475
 Center loss: 0.1484850840943475


100%|██████████| 1380/1380 [00:17<00:00, 80.40it/s]


 ------------ | Validation: loss = 0.1094957422515754
 Center loss: 0.07002619251141079 | Radius loss: 0.03946954990924846


100%|██████████| 62/62 [00:20<00:00,  3.08it/s]


 Epoch: 64 | Training: loss = 0.1372168937758092
 Center loss: 0.1372168937758092


100%|██████████| 1380/1380 [00:17<00:00, 81.03it/s]


 ------------ | Validation: loss = 0.1085442775303449
 Center loss: 0.06998324834202177 | Radius loss: 0.038561029791741065


100%|██████████| 62/62 [00:20<00:00,  3.09it/s]


 Epoch: 65 | Training: loss = 0.15053429418513853
 Center loss: 0.15053429418513853


100%|██████████| 1380/1380 [00:16<00:00, 81.19it/s]


 ------------ | Validation: loss = 0.1200040252466325
 Center loss: 0.07790289734060196 | Radius loss: 0.042101127755376666


100%|██████████| 62/62 [00:20<00:00,  3.09it/s]


 Epoch: 66 | Training: loss = 0.14046741036638136
 Center loss: 0.14046741036638136


100%|██████████| 1380/1380 [00:17<00:00, 81.01it/s]


 ------------ | Validation: loss = 0.1085297996534264
 Center loss: 0.06943831002069561 | Radius loss: 0.03909148932482994


100%|██████████| 62/62 [00:20<00:00,  3.03it/s]


 Epoch: 67 | Training: loss = 0.1314806666585707
 Center loss: 0.1314806666585707


100%|██████████| 1380/1380 [00:16<00:00, 81.58it/s]


 ------------ | Validation: loss = 0.119502130595414
 Center loss: 0.07565842904274388 | Radius loss: 0.04384370217355992


100%|██████████| 62/62 [00:20<00:00,  3.02it/s]


 Epoch: 68 | Training: loss = 0.15385126707053953
 Center loss: 0.15385126707053953


100%|██████████| 1380/1380 [00:16<00:00, 81.32it/s]


 ------------ | Validation: loss = 0.10267152650840876
 Center loss: 0.06673369458451328 | Radius loss: 0.03593783193455705


100%|██████████| 62/62 [00:20<00:00,  3.07it/s]


 Epoch: 69 | Training: loss = 0.1625630320800889
 Center loss: 0.1625630320800889


100%|██████████| 1380/1380 [00:16<00:00, 81.73it/s]


 ------------ | Validation: loss = 0.11318812633728975
 Center loss: 0.07458869115747065 | Radius loss: 0.03859943470027172


100%|██████████| 62/62 [00:20<00:00,  3.07it/s]


 Epoch: 70 | Training: loss = 0.13866128592241195
 Center loss: 0.13866128592241195


100%|██████████| 1380/1380 [00:17<00:00, 80.46it/s]


 ------------ | Validation: loss = 0.11165890744730174
 Center loss: 0.07251097511956607 | Radius loss: 0.03914793243411314


100%|██████████| 62/62 [00:20<00:00,  3.09it/s]


 Epoch: 71 | Training: loss = 0.14002672811189004
 Center loss: 0.14002672811189004


100%|██████████| 1380/1380 [00:17<00:00, 80.31it/s]


 ------------ | Validation: loss = 0.11268412205795028
 Center loss: 0.07267047307425367 | Radius loss: 0.0400136484820166


100%|██████████| 62/62 [00:20<00:00,  3.08it/s]


 Epoch: 72 | Training: loss = 0.1345745654356095
 Center loss: 0.1345745654356095


100%|██████████| 1380/1380 [00:16<00:00, 81.36it/s]


 ------------ | Validation: loss = 0.10666167393520364
 Center loss: 0.06865753540872559 | Radius loss: 0.038004138900599935


100%|██████████| 62/62 [00:21<00:00,  2.94it/s]


 Epoch: 73 | Training: loss = 0.13596895481309584
 Center loss: 0.13596895481309584


100%|██████████| 1380/1380 [00:17<00:00, 80.53it/s]


 ------------ | Validation: loss = 0.10630490555353311
 Center loss: 0.06926104172216319 | Radius loss: 0.03704386442563421


100%|██████████| 62/62 [00:20<00:00,  3.05it/s]


 Epoch: 74 | Training: loss = 0.1337866017655019
 Center loss: 0.1337866017655019


100%|██████████| 1380/1380 [00:17<00:00, 81.11it/s]


 ------------ | Validation: loss = 0.11010786765387899
 Center loss: 0.06983586329107588 | Radius loss: 0.04027200424611715


100%|██████████| 62/62 [00:20<00:00,  3.07it/s]


 Epoch: 75 | Training: loss = 0.12113099064557784
 Center loss: 0.12113099064557784


100%|██████████| 1380/1380 [00:16<00:00, 81.89it/s]


 ------------ | Validation: loss = 0.10831673544818077
 Center loss: 0.06855522061628964 | Radius loss: 0.039761513887691925


100%|██████████| 62/62 [00:20<00:00,  3.06it/s]


 Epoch: 76 | Training: loss = 0.1286056501971137
 Center loss: 0.1286056501971137


100%|██████████| 1380/1380 [00:17<00:00, 80.35it/s]


 ------------ | Validation: loss = 0.10531572943411606
 Center loss: 0.06820601073322115 | Radius loss: 0.03710971738923163


100%|██████████| 62/62 [00:20<00:00,  3.05it/s]


 Epoch: 77 | Training: loss = 0.12672580333967362
 Center loss: 0.12672580333967362


100%|██████████| 1380/1380 [00:16<00:00, 81.99it/s]


 ------------ | Validation: loss = 0.10385547246378206
 Center loss: 0.06607118840131106 | Radius loss: 0.03778428346307324


100%|██████████| 62/62 [00:20<00:00,  3.04it/s]


 Epoch: 78 | Training: loss = 0.1303775699869279
 Center loss: 0.1303775699869279


100%|██████████| 1380/1380 [00:17<00:00, 80.96it/s]


 ------------ | Validation: loss = 0.10791329453373771
 Center loss: 0.0725918463975203 | Radius loss: 0.035321447551234106


100%|██████████| 62/62 [00:20<00:00,  3.07it/s]


 Epoch: 79 | Training: loss = 0.1299091947655524
 Center loss: 0.1299091947655524


100%|██████████| 1380/1380 [00:16<00:00, 81.22it/s]


 ------------ | Validation: loss = 0.10212248665444297
 Center loss: 0.06637655477942894 | Radius loss: 0.035745931668063814


100%|██████████| 62/62 [00:20<00:00,  3.03it/s]


 Epoch: 80 | Training: loss = 0.13969343095537154
 Center loss: 0.13969343095537154


100%|██████████| 1380/1380 [00:17<00:00, 80.41it/s]


 ------------ | Validation: loss = 0.10849959916765817
 Center loss: 0.06935162469424645 | Radius loss: 0.039147973662890766


100%|██████████| 62/62 [00:21<00:00,  2.93it/s]


 Epoch: 81 | Training: loss = 0.1233835419820201
 Center loss: 0.1233835419820201


100%|██████████| 1380/1380 [00:17<00:00, 79.63it/s]


 ------------ | Validation: loss = 0.0937954394726715
 Center loss: 0.06096360023107411 | Radius loss: 0.032831840069099935


100%|██████████| 62/62 [00:20<00:00,  3.04it/s]


 Epoch: 82 | Training: loss = 0.12661957548510644
 Center loss: 0.12661957548510644


100%|██████████| 1380/1380 [00:17<00:00, 79.98it/s]


 ------------ | Validation: loss = 0.10071615663608469
 Center loss: 0.06500084557932417 | Radius loss: 0.035715310272197084


100%|██████████| 62/62 [00:20<00:00,  3.07it/s]


 Epoch: 83 | Training: loss = 0.12615628732788947
 Center loss: 0.12615628732788947


100%|██████████| 1380/1380 [00:16<00:00, 82.11it/s]


 ------------ | Validation: loss = 0.10799488341949325
 Center loss: 0.06962310254622961 | Radius loss: 0.03837178049203896


100%|██████████| 62/62 [00:19<00:00,  3.13it/s]


 Epoch: 84 | Training: loss = 0.12250275477286308
 Center loss: 0.12250275477286308


100%|██████████| 1380/1380 [00:17<00:00, 81.04it/s]


 ------------ | Validation: loss = 0.10324284046917033
 Center loss: 0.06678646021835553 | Radius loss: 0.0364563803878434


100%|██████████| 62/62 [00:20<00:00,  3.07it/s]


 Epoch: 85 | Training: loss = 0.1241613072493384
 Center loss: 0.1241613072493384


100%|██████████| 1380/1380 [00:17<00:00, 80.42it/s]


 ------------ | Validation: loss = 0.10177984732999112
 Center loss: 0.06493451876765217 | Radius loss: 0.036845328365177565


100%|██████████| 62/62 [00:20<00:00,  3.05it/s]


 Epoch: 86 | Training: loss = 0.13294370893028476
 Center loss: 0.13294370893028476


100%|██████████| 1380/1380 [00:16<00:00, 81.74it/s]


 ------------ | Validation: loss = 0.10540657703316791
 Center loss: 0.07043554754944627 | Radius loss: 0.03497102944563294


100%|██████████| 62/62 [00:19<00:00,  3.12it/s]


 Epoch: 87 | Training: loss = 0.13170504702195043
 Center loss: 0.13170504702195043


100%|██████████| 1380/1380 [00:17<00:00, 81.08it/s]


 ------------ | Validation: loss = 0.09970322568988788
 Center loss: 0.06410324582462583 | Radius loss: 0.03559997936813377


100%|██████████| 62/62 [00:20<00:00,  3.03it/s]


 Epoch: 88 | Training: loss = 0.12307497210079624
 Center loss: 0.12307497210079624


100%|██████████| 1380/1380 [00:17<00:00, 80.90it/s]


 ------------ | Validation: loss = 0.10191338548543896
 Center loss: 0.06427576348644874 | Radius loss: 0.037637622774896656


100%|██████████| 62/62 [00:19<00:00,  3.13it/s]


 Epoch: 89 | Training: loss = 0.11963087488566676
 Center loss: 0.11963087488566676


100%|██████████| 1380/1380 [00:16<00:00, 81.70it/s]


 ------------ | Validation: loss = 0.09558577465137086
 Center loss: 0.061068577983876764 | Radius loss: 0.034517196695849864


100%|██████████| 62/62 [00:20<00:00,  3.03it/s]


 Epoch: 90 | Training: loss = 0.1255009131085488
 Center loss: 0.1255009131085488


100%|██████████| 1380/1380 [00:16<00:00, 82.90it/s]


 ------------ | Validation: loss = 0.10070256594763519
 Center loss: 0.06426442548626506 | Radius loss: 0.03643814046597841


100%|██████████| 62/62 [00:19<00:00,  3.11it/s]


 Epoch: 91 | Training: loss = 0.1215057774416862
 Center loss: 0.1215057774416862


100%|██████████| 1380/1380 [00:16<00:00, 81.69it/s]


 ------------ | Validation: loss = 0.103330121921391
 Center loss: 0.06706988089382071 | Radius loss: 0.03626024142724581


100%|██████████| 62/62 [00:20<00:00,  3.07it/s]


 Epoch: 92 | Training: loss = 0.1216806712890825
 Center loss: 0.1216806712890825


100%|██████████| 1380/1380 [00:16<00:00, 81.85it/s]


 ------------ | Validation: loss = 0.09650875977013788
 Center loss: 0.06151283727481116 | Radius loss: 0.03499592229632621


100%|██████████| 62/62 [00:19<00:00,  3.12it/s]


 Epoch: 93 | Training: loss = 0.1273737492099885
 Center loss: 0.1273737492099885


100%|██████████| 1380/1380 [00:16<00:00, 81.89it/s]


 ------------ | Validation: loss = 0.10369852900102147
 Center loss: 0.06553991298654485 | Radius loss: 0.03815861636725624


100%|██████████| 62/62 [00:19<00:00,  3.10it/s]


 Epoch: 94 | Training: loss = 0.14140330146877997
 Center loss: 0.14140330146877997


100%|██████████| 1380/1380 [00:16<00:00, 81.80it/s]


 ------------ | Validation: loss = 0.10639341168555447
 Center loss: 0.07008260645198143 | Radius loss: 0.03631080603545839


100%|██████████| 62/62 [00:20<00:00,  3.04it/s]


 Epoch: 95 | Training: loss = 0.1253827670889516
 Center loss: 0.1253827670889516


100%|██████████| 1380/1380 [00:16<00:00, 81.27it/s]


 ------------ | Validation: loss = 0.10286658475727208
 Center loss: 0.06598037674561635 | Radius loss: 0.03688620866832893


100%|██████████| 62/62 [00:20<00:00,  2.95it/s]


 Epoch: 96 | Training: loss = 0.1221843128723483
 Center loss: 0.1221843128723483


100%|██████████| 1380/1380 [00:16<00:00, 81.29it/s]


 ------------ | Validation: loss = 0.0984322553017262
 Center loss: 0.061735465300304326 | Radius loss: 0.03669678986881362


100%|██████████| 62/62 [00:20<00:00,  3.02it/s]


 Epoch: 97 | Training: loss = 0.13449730495772055
 Center loss: 0.13449730495772055


100%|██████████| 1380/1380 [00:17<00:00, 80.75it/s]


 ------------ | Validation: loss = 0.09870327176654
 Center loss: 0.06466938687313332 | Radius loss: 0.034033885460379935


100%|██████████| 62/62 [00:20<00:00,  3.03it/s]


 Epoch: 98 | Training: loss = 0.11932434862659823
 Center loss: 0.11932434862659823


100%|██████████| 1380/1380 [00:16<00:00, 81.67it/s]


 ------------ | Validation: loss = 0.09649243053969732
 Center loss: 0.06111031416876885 | Radius loss: 0.03538211668665725


100%|██████████| 62/62 [00:21<00:00,  2.86it/s]


 Epoch: 99 | Training: loss = 0.11889146676947994
 Center loss: 0.11889146676947994


100%|██████████| 1380/1380 [00:17<00:00, 79.06it/s]

 ------------ | Validation: loss = 0.09633313022930688
 Center loss: 0.061373527929756984 | Radius loss: 0.034959602174916934





In [4]:
# Persist the trained sphere regressor.
# The network is moved to the CPU *before* saving, so the checkpoint holds CPU
# tensors and can be loaded with a plain `torch.load` on machines without a GPU
# (no `map_location` needed). `Module.to` works in place, so this also takes the
# parameters off the GPU to free memory.
sphere_net.to(cpu_device)
# torch.save does not create missing directories.
os.makedirs(os.path.dirname(sphere_checkpoint) or ".", exist_ok=True)
torch.save(sphere_net.state_dict(), sphere_checkpoint)
# Moving tensors off the GPU only returns their blocks to PyTorch's caching
# allocator; release the cached memory as well (no-op if CUDA was never initialized).
torch.cuda.empty_cache()
# Time statistics: `start` / `finish` are expected to be set by the training cell.
print(f"Time to train sphere regressor network: {finish - start} sec")

Time to train sphere regressor network: 3737.816478252411 sec
