In [1]:
# imports
from dataset import SHREC2022Primitives
from networks import MinkowskiFCNN
import transforms as t
from losses import TorusLoss
import MinkowskiEngine as ME


import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import os

# configuration & hyperparameters
# Dataset root. Set the SHREC2022_PATH environment variable to point the
# notebook at another location without editing this cell; the fallback is the
# original machine-specific directory.
path = os.environ.get(
    "SHREC2022_PATH",
    "/home/ioannis/Desktop/programming/data/SHREC/SHREC2022/dataset",
)
batch_size = 128               # samples per batch for both data loaders
lr = 1e-3                      # learning rate handed to the optimizer
num_epochs = 100               # number of passes of the training loop
eval_step = 1000000000         # NOTE(review): not referenced by the cells shown here
regression_targets = 8         # passed to the network as out_channel

In [2]:
def _stack_optional(list_data, key):
    """Stack d[key] over all samples, or return None if no sample has key.

    Raises:
        ValueError: if only some samples carry key; the stacked rows could
            then no longer be matched one-to-one with the batch samples.
    """
    values = [d[key] for d in list_data if key in d]
    if not values:
        return None
    if len(values) != len(list_data):
        raise ValueError(
            f"'{key}' is present in {len(values)} of {len(list_data)} samples; "
            "it must be present in all samples or in none"
        )
    return torch.stack(values)


def minkowski_collate(list_data):
    """Collate a list of sample dicts into one batch dict.

    Each sample must provide 'x', 'y', 'norm_factor' and 'shift'. The keys
    'inverse_rotation' and 'mean' are optional, but each must be present in
    either every sample or none of them.

    Args:
        list_data: list of per-sample dicts.

    Returns:
        dict with "coordinates", "features" and "labels" (as returned by
        ME.utils.sparse_collate), "means" (stacked 'mean' entries or None) and
        "trans", a dict holding the stacked "norm_factors" and "shifts" plus
        "inv_rotations" (stacked 'inverse_rotation' entries or None).

    Raises:
        ValueError: if an optional key is present in only part of the batch.
    """
    # d['x'] is passed both as coordinates and as features.
    coordinates, features, labels = ME.utils.sparse_collate(
        [d['x'] for d in list_data],
        [d['x'] for d in list_data],
        [d['y'].unsqueeze(0) for d in list_data],
        dtype = torch.float32
    )

    # collating other data
    norm_factors = torch.stack([d['norm_factor'] for d in list_data])
    shifts = torch.stack([d['shift'] for d in list_data])

    # Optional entries: a pipeline that does not produce them yields None
    # instead of a KeyError or a tensor with fewer rows than the batch.
    inv_rotations = _stack_optional(list_data, 'inverse_rotation')
    means = _stack_optional(list_data, 'mean')

    return {
        "coordinates"   : coordinates,
        "features"      : features,
        "labels"        : labels,
        "means"         : means,
        "trans": {"norm_factors"  : norm_factors,
                  "shifts"        : shifts,
                  "inv_rotations" : inv_rotations
                 }
        }


def create_input_batch(batch, device="cuda", quantization_size=0.05):
    """Build the MinkowskiEngine input field for one collated batch.

    Args:
        batch: dict with "coordinates" and "features" entries.
        device: device on which the TensorField is created.
        quantization_size: divisor applied to every coordinate column except
            column 0.

    Returns:
        ME.TensorField built from the rescaled coordinates and the features.
    """
    # Rescale a copy: writing into batch["coordinates"] would alter the
    # caller's batch, so calling this twice on one batch would scale twice.
    coordinates = batch["coordinates"].clone()
    # Column 0 is left as is (presumably the batch index added by the collate
    # step - confirm); only the remaining columns are divided.
    coordinates[:, 1:] = coordinates[:, 1:] / quantization_size
    return ME.TensorField(
        coordinates=coordinates,
        features=batch["features"],
        device=device
    )

In [3]:
# Dataset and dataloader
# Transform pipeline applied to training samples, in this order.
train_transforms = [
    t.Translate(),
    t.SphereNormalization(),
    t.Initialization(),
    t.RandomRotate(180, 0),
    t.RandomRotate(180, 1),
    t.RandomRotate(180, 2),
    t.GaussianNoise(),
    t.GetMean(),
]

# NOTE(review): the validation pipeline has no Initialization / GetMean step;
# confirm its samples still carry every key the collate function reads.
valid_transforms = [
    t.Translate(),
    t.SphereNormalization(),
]

# Arguments shared by the two dataset objects and by the two loaders.
_dataset_kwargs = dict(train=True, valid_split=0.2, category="torus")
_loader_kwargs = dict(
    batch_size=batch_size,
    collate_fn=minkowski_collate,
    num_workers=8,
)

# Training split, shuffled on every pass.
t_dataset = SHREC2022Primitives(
    path, valid=False, transform=train_transforms, **_dataset_kwargs
)
train_loader = DataLoader(t_dataset, shuffle=True, **_loader_kwargs)

# Validation split, iterated in a fixed order.
v_dataset = SHREC2022Primitives(
    path, valid=True, transform=valid_transforms, **_dataset_kwargs
)
valid_loader = DataLoader(v_dataset, shuffle=False, **_loader_kwargs)

Specified split already exists. Using the existing one.
Specified split already exists. Using the existing one.


In [4]:
# network and device
# Use the first CUDA device when one is available, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# 3 input channels; one output channel per regression target.
net = MinkowskiFCNN(in_channel=3, out_channel=regression_targets)
net = net.to(device)

Using device: cuda:0


In [5]:
# optimizer and losses
# Loss object called by the training loop with (pred, gt, batch["trans"]).
criterion = TorusLoss()
# Adam over all network parameters, using the configured learning rate.
optimizer = torch.optim.Adam(params=net.parameters(), lr=lr)

In [6]:
# Training loop

for epoch in range(num_epochs):
    # Running sums of the per-batch loss terms, reset at the start of each epoch.
    m_loss = m_center_loss = m_axis_loss = m_radius_loss = m_Radius_loss = 0

    for batch in tqdm(train_loader):
        optimizer.zero_grad()

        # Every label column except the first is used as the regression target.
        labels = batch["labels"].to(device)
        gt = labels[:, 1:]

        # Forward pass on the sparse input built from the collated batch.
        minknet_input = create_input_batch(
            batch, device=device, quantization_size=0.05
        )
        pred = net(minknet_input)

        # The criterion yields four terms, unpacked in the order a, c, R, r;
        # each one is averaged over dimension 0.
        a_loss, c_loss, R_loss, r_loss = (
            term.mean(0) for term in criterion(pred, gt, batch["trans"])
        )

        # Unweighted sum of the four terms is the quantity being optimised.
        loss = R_loss + r_loss + c_loss + a_loss
        loss.backward()
        optimizer.step()

        # Accumulate plain Python floats for the epoch report.
        m_loss += loss.item()
        m_center_loss += c_loss.item()
        m_axis_loss   += a_loss.item()
        m_radius_loss += r_loss.item()
        m_Radius_loss += R_loss.item()

    # Turn the sums into per-batch averages.
    num_batches = len(train_loader)
    m_loss /= num_batches
    m_center_loss /= num_batches
    m_axis_loss   /= num_batches
    m_radius_loss /= num_batches
    m_Radius_loss /= num_batches

    print(f" Epoch: {epoch} | Training: loss = {m_loss}")
    print(f" Center loss: {m_center_loss} | Axis loss: {m_axis_loss} | Radius loss: {m_Radius_loss} | radius loss: {m_radius_loss}")


print("Done!")

100%|██████████| 58/58 [00:23<00:00,  2.42it/s]


 Epoch: 0 | Training: loss = 10.03419410771337
 Center loss: 7.037420971640225 | Axis loss: 0.6480024364487879 | Radius loss: 1.594654778982031 | radius loss: 0.7541158697728453


100%|██████████| 58/58 [00:23<00:00,  2.42it/s]


 Epoch: 1 | Training: loss = 7.140303340451471
 Center loss: 5.392571893231622 | Axis loss: 0.6133089425234959 | Radius loss: 0.7253201161992961 | radius loss: 0.40910241161954813


100%|██████████| 58/58 [00:23<00:00,  2.43it/s]


 Epoch: 2 | Training: loss = 6.125022797748961
 Center loss: 4.508732688838038 | Axis loss: 0.6209453097705183 | Radius loss: 0.6683454467304821 | radius loss: 0.3269992935760268


100%|██████████| 58/58 [00:24<00:00,  2.39it/s]


 Epoch: 3 | Training: loss = 4.91712630206141
 Center loss: 3.3710998420057625 | Axis loss: 0.5549524883771765 | Radius loss: 0.6642159233833181 | radius loss: 0.3268580310817423


100%|██████████| 58/58 [00:23<00:00,  2.44it/s]


 Epoch: 4 | Training: loss = 4.18831099312881
 Center loss: 2.775596754304294 | Axis loss: 0.4841079085037626 | Radius loss: 0.6167122531553795 | radius loss: 0.3118940445369688


100%|██████████| 58/58 [00:23<00:00,  2.42it/s]


 Epoch: 5 | Training: loss = 3.7528813871844062
 Center loss: 2.4147005656669878 | Axis loss: 0.4505870049369746 | Radius loss: 0.5899211707813986 | radius loss: 0.2976726522219592


100%|██████████| 58/58 [00:24<00:00,  2.33it/s]


 Epoch: 6 | Training: loss = 3.462562585699147
 Center loss: 2.183929418695384 | Axis loss: 0.4273290700953582 | Radius loss: 0.5637718721710402 | radius loss: 0.28753222833419667


100%|██████████| 58/58 [00:24<00:00,  2.36it/s]


 Epoch: 7 | Training: loss = 3.2376174926757812
 Center loss: 2.0242228898508796 | Axis loss: 0.4028248293646451 | Radius loss: 0.546050613296443 | radius loss: 0.2645191586223142


100%|██████████| 58/58 [00:24<00:00,  2.42it/s]


 Epoch: 8 | Training: loss = 3.0265957240400643
 Center loss: 1.8892425894737244 | Axis loss: 0.3785428975162835 | Radius loss: 0.5106895463220005 | radius loss: 0.24812068661739087


  0%|          | 0/58 [00:03<?, ?it/s]


KeyboardInterrupt: 