In [1]:
# imports
from dataset import SHREC2022Primitives
from networks import MinkowskiFCNN
import transforms as t
from losses import ConeLoss
import MinkowskiEngine as ME


import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import os

# configuration & hyperparameters
# Dataset root. Set the SHREC2022_DATA_DIR environment variable to point at a
# different location; when it is unset the original local path is used.
path = os.environ.get(
    "SHREC2022_DATA_DIR",
    "/home/ioannis/Desktop/programming/data/SHREC/SHREC2022/dataset",
)
batch_size = 128  # samples per DataLoader batch
lr = 1e-3  # learning rate handed to the Adam optimizer
num_epochs = 100  # number of passes of the training loop
# NOTE(review): eval_step is not referenced by any code visible in this notebook.
eval_step = 1_000_000_000
regression_targets = 7  # passed to the network as out_channel

In [2]:
def minkowski_collate(list_data):
    """Collate a list of dataset samples into a single batch dictionary.

    Each sample ``d`` is a dict that must provide ``'x'``, ``'y'``,
    ``'norm_factor'`` and ``'shift'``. ``'inverse_rotation'`` and ``'mean'``
    are optional: a pipeline that does not produce them (presumably one
    without the rotation / mean transforms -- TODO confirm against the
    dataset) yields ``None`` for the corresponding entry instead of raising
    ``KeyError``.

    Args:
        list_data: list of sample dicts as yielded by the dataset.

    Returns:
        dict with keys ``"coordinates"``, ``"features"`` and ``"labels"``
        (the outputs of ``ME.utils.sparse_collate``), ``"means"`` (stacked
        tensor or ``None``) and ``"trans"``, a dict holding the stacked
        ``"norm_factors"`` and ``"shifts"`` plus ``"inv_rotations"``
        (stacked tensor or ``None``).
    """
    # The points in d['x'] are used both as coordinates and as features.
    coordinates, features, labels = ME.utils.sparse_collate(
        [d['x'] for d in list_data],
        [d['x'] for d in list_data],
        [d['y'].unsqueeze(0) for d in list_data],
        dtype = torch.float32
    )

    # Per-sample transform metadata that every sample must carry.
    norm_factors = torch.stack([d['norm_factor'] for d in list_data])
    shifts = torch.stack([d['shift'] for d in list_data])

    # Optional per-sample metadata: gather it only from samples that have it.
    # NOTE(review): if only some samples of a batch carry a key, the stacked
    # tensor has fewer rows than the batch -- same as the original behaviour
    # for 'inverse_rotation'.
    inv_rotations = [d['inverse_rotation'] for d in list_data
                     if 'inverse_rotation' in d]
    means = [d['mean'] for d in list_data if 'mean' in d]

    return {
        "coordinates"   : coordinates,
        "features"      : features,
        "labels"        : labels,
        "means"         : torch.stack(means) if means else None,
        "trans": {"norm_factors"  : norm_factors,
                  "shifts"        : shifts,
                  "inv_rotations" : torch.stack(inv_rotations) if inv_rotations else None
                 }
        }


def create_input_batch(batch, device="cuda", quantization_size=0.05):
    """Build the ``ME.TensorField`` network input from a collated batch.

    Every coordinate column except the first is divided by
    ``quantization_size``; column 0 is left untouched (presumably the batch
    index written by ``ME.utils.sparse_collate`` -- TODO confirm).

    The scaling is done on a copy, so ``batch["coordinates"]`` is not
    modified and calling this function again on the same batch gives the
    same result.

    Args:
        batch: dict with ``"coordinates"`` and ``"features"`` tensors, as
            returned by ``minkowski_collate``.
        device: device on which the tensor field is created.
        quantization_size: divisor applied to the spatial coordinates.

    Returns:
        ``ME.TensorField`` built from the scaled coordinates and the
        features of ``batch``.
    """
    # Clone first: assigning into batch["coordinates"] directly would rescale
    # the caller's tensor in place.
    coordinates = batch["coordinates"].clone()
    coordinates[:, 1:] = coordinates[:, 1:] / quantization_size
    return ME.TensorField(
        coordinates=coordinates,
        features=batch["features"],
        device=device
    )

In [3]:
# Dataset and dataloader
# Transform pipeline for the training samples.
train_transforms = [
    t.Translate(),
    t.SphereNormalization(),
    t.Initialization(),
    t.RandomRotate(180, 0),
    t.RandomRotate(180, 1),
    t.RandomRotate(180, 2),
    t.GaussianNoise(),
    t.GetMean(),
]

# Transform pipeline for the validation samples.
valid_transforms = [t.Translate(), t.SphereNormalization()]

# Arguments shared by both dataset objects and by both loaders.
_split_options = dict(train=True, valid_split=0.2, category="cone")
_loader_options = dict(batch_size=batch_size,
                       collate_fn=minkowski_collate,
                       num_workers=8)

# Training portion of the split (built first, as in the original order),
# shuffled on every pass.
t_dataset = SHREC2022Primitives(path,
                                valid=False,
                                transform=train_transforms,
                                **_split_options)
train_loader = DataLoader(t_dataset, shuffle=True, **_loader_options)

# Validation portion of the same split, iterated in a fixed order.
v_dataset = SHREC2022Primitives(path,
                                valid=True,
                                transform=valid_transforms,
                                **_split_options)
valid_loader = DataLoader(v_dataset, shuffle=False, **_loader_options)

Creating a new train-validation split.
Specified split already exists. Using the existing one.


In [4]:
# network and device
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# 3 input channels; one output channel per regression target.
net = MinkowskiFCNN(in_channel=3, out_channel=regression_targets)
net = net.to(device)

Using device: cuda:0


In [5]:
# optimizer and losses
# Loss object called in the training loop as criterion(pred, gt, trans).
criterion = ConeLoss()
# Adam over all network parameters with the configured learning rate.
optimizer = torch.optim.Adam(params=net.parameters(), lr=lr)

In [6]:
# Training loop
# Runs num_epochs passes over train_loader, optimising the sum of the three
# loss terms returned by the criterion and printing their per-epoch averages.

for epoch in range(num_epochs):
    # Put the network in training mode explicitly: if this cell is re-run
    # after net.eval() was called elsewhere in the notebook, layers such as
    # dropout / batch-norm would otherwise stay in inference mode.
    net.train()

    # running sums of the per-batch losses, averaged after the epoch
    m_loss = 0
    m_axis_loss = 0
    m_vertex_loss = 0
    m_theta_loss = 0
    
    for batch in tqdm(train_loader):

        optimizer.zero_grad()

        labels = batch["labels"].to(device)

        minknet_input = create_input_batch(
            batch, 
            device=device,
            quantization_size=0.05
        )

        pred = net(minknet_input)       
        
        # Drop the first label column; the remaining columns are the targets
        # handed to the criterion.
        # NOTE(review): column 0 presumably holds a non-regression label -- confirm.
        gt = labels[:,1:]

        # The criterion returns three loss terms (named axis / vertex / theta
        # in the logs below); each is averaged over its first dimension.
        a_loss, v_loss, t_loss = criterion(pred, gt, batch["trans"])
        a_loss = a_loss.mean(0)
        v_loss = v_loss.mean(0)
        t_loss = t_loss.mean(0)
        
        # Unweighted sum of the three terms is the optimisation objective.
        loss = a_loss + v_loss + t_loss
        
        loss.backward()
        optimizer.step()
        
        m_axis_loss += a_loss.item()
        m_vertex_loss += v_loss.item()
        m_theta_loss += t_loss.item()
        m_loss += loss.item()

    # Average of the per-batch values over the number of batches.
    num_batches = len(train_loader)
    m_loss /= num_batches
    m_axis_loss /= num_batches
    m_vertex_loss /= num_batches
    m_theta_loss /= num_batches
    
    print(f" Epoch: {epoch} | Training: loss = {m_loss}")
    print(f" Axis loss: {m_axis_loss} | Vertex loss: {m_vertex_loss} | Theta loss: {m_theta_loss}")
   

print("Done!")

100%|██████████| 58/58 [00:17<00:00,  3.37it/s]


 Epoch: 0 | Training: loss = 2.2262258735196343
 Axis loss: 0.5998672431912916 | Vertex loss: 1.5168131774869458 | Theta loss: 0.10954545386906328


100%|██████████| 58/58 [00:16<00:00,  3.62it/s]


 Epoch: 1 | Training: loss = 1.7048310514154106
 Axis loss: 0.4597485301823452 | Vertex loss: 1.2007674636511967 | Theta loss: 0.044315043226655186


100%|██████████| 58/58 [00:17<00:00,  3.22it/s]


 Epoch: 2 | Training: loss = 1.4480628083492149
 Axis loss: 0.43555011625947626 | Vertex loss: 0.9707780583151455 | Theta loss: 0.041734626034981216


100%|██████████| 58/58 [00:17<00:00,  3.38it/s]


 Epoch: 3 | Training: loss = 1.2463349103927612
 Axis loss: 0.41396724869465007 | Vertex loss: 0.7978138594791807 | Theta loss: 0.034553805558845914


100%|██████████| 58/58 [00:16<00:00,  3.50it/s]


 Epoch: 4 | Training: loss = 1.1613707090246266
 Axis loss: 0.4032480464927081 | Vertex loss: 0.7267983072790606 | Theta loss: 0.03132436206114703


100%|██████████| 58/58 [00:17<00:00,  3.35it/s]


 Epoch: 5 | Training: loss = 1.0915629195755925
 Axis loss: 0.3924230470739562 | Vertex loss: 0.6677274262082988 | Theta loss: 0.031412448669815886


100%|██████████| 58/58 [00:17<00:00,  3.35it/s]


 Epoch: 6 | Training: loss = 1.07267741499276
 Axis loss: 0.3866957297612881 | Vertex loss: 0.6570151962082962 | Theta loss: 0.028966496152610613


100%|██████████| 58/58 [00:17<00:00,  3.36it/s]


 Epoch: 7 | Training: loss = 1.0191281982536973
 Axis loss: 0.37719480950256873 | Vertex loss: 0.6128022830034124 | Theta loss: 0.029131103050092172


100%|██████████| 58/58 [00:17<00:00,  3.29it/s]


 Epoch: 8 | Training: loss = 1.001510390947605
 Axis loss: 0.37132327412736826 | Vertex loss: 0.6017595519279612 | Theta loss: 0.028427562836943


100%|██████████| 58/58 [00:16<00:00,  3.50it/s]


 Epoch: 9 | Training: loss = 0.973976738494018
 Axis loss: 0.3655525892972946 | Vertex loss: 0.5784382337126238 | Theta loss: 0.029985919113046135


100%|██████████| 58/58 [00:16<00:00,  3.57it/s]


 Epoch: 10 | Training: loss = 0.9647810829096827
 Axis loss: 0.3658490699940714 | Vertex loss: 0.5718952509863623 | Theta loss: 0.027036763920352376


100%|██████████| 58/58 [00:17<00:00,  3.37it/s]


 Epoch: 11 | Training: loss = 0.9461602106176573
 Axis loss: 0.3595932547388406 | Vertex loss: 0.560010971694157 | Theta loss: 0.026555977312141453


100%|██████████| 58/58 [00:17<00:00,  3.41it/s]


 Epoch: 12 | Training: loss = 0.9310127763912596
 Axis loss: 0.3526539473698057 | Vertex loss: 0.5519958364552465 | Theta loss: 0.02636298444122076


100%|██████████| 58/58 [00:16<00:00,  3.45it/s]


 Epoch: 13 | Training: loss = 0.9204617337933902
 Axis loss: 0.3437987458089302 | Vertex loss: 0.5510395667676268 | Theta loss: 0.02562341578947059


100%|██████████| 58/58 [00:18<00:00,  3.20it/s]


 Epoch: 14 | Training: loss = 0.9136967360973358
 Axis loss: 0.3419266014263548 | Vertex loss: 0.5457602174117647 | Theta loss: 0.026009920920277464


100%|██████████| 58/58 [00:17<00:00,  3.39it/s]


 Epoch: 15 | Training: loss = 0.8905877532630131
 Axis loss: 0.3380506907043786 | Vertex loss: 0.5256760500628372 | Theta loss: 0.02686100796764267


100%|██████████| 58/58 [00:17<00:00,  3.36it/s]


 Epoch: 16 | Training: loss = 0.8883717460878964
 Axis loss: 0.33169065210325965 | Vertex loss: 0.5304136348181757 | Theta loss: 0.026267457689190733


 55%|█████▌    | 32/58 [00:10<00:08,  3.08it/s]


KeyboardInterrupt: 