In [1]:
# imports
from dataset import SHREC2022Primitives
from networks import MinkowskiFCNN
import transforms as t
from losses import CylinderLoss
import MinkowskiEngine as ME


import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import os

In [2]:
# configuration & hyperparameters
#
# Filesystem locations default to the original author's machine but can be
# overridden without editing the notebook by exporting SHREC2022_DATA_DIR /
# SHREC2022_CHECKPOINT_DIR before starting the kernel.
path = os.environ.get(
    "SHREC2022_DATA_DIR",
    "/home/ioannis/Desktop/programming/data/SHREC/SHREC2022/dataset",
)
batch_size = 128
lr = 1e-3
num_epochs = 100
# NOTE(review): eval_step is not read by any cell visible in this notebook.
eval_step = 1000000000
# Number of outputs of the regression head (passed as out_channel to the network).
regression_targets = 7


checkpoint_path = os.environ.get(
    "SHREC2022_CHECKPOINT_DIR",
    "/home/ioannis/Desktop/programming/phd/SHREC/SHREC2022/checkpoints",
)
checkpoint_file = os.path.join(checkpoint_path, "cylinder_100.pth")

In [3]:
def _stack_optional(list_data, key):
    """Stack ``d[key]`` over the batch; return None when no sample has the key.

    Raises:
        ValueError: if only some samples carry ``key``. Stacking just those
            would produce a tensor whose rows no longer line up with the batch.
    """
    present = [d[key] for d in list_data if key in d]
    if not present:
        return None
    if len(present) != len(list_data):
        raise ValueError(
            f"'{key}' is present in {len(present)} of {len(list_data)} samples; "
            "it must be present in all samples of a batch or in none"
        )
    return torch.stack(present)


def minkowski_collate(list_data):
    """Collate a list of sample dicts into one batch for MinkowskiEngine.

    Each sample must provide ``'x'``, ``'y'``, ``'norm_factor'`` and
    ``'shift'``. ``'inverse_rotation'`` and ``'mean'`` are optional, but each
    must be present either in every sample of the batch or in none.

    Args:
        list_data: list of per-sample dicts as produced by the dataset.

    Returns:
        dict with keys:
            ``"coordinates"``, ``"features"``, ``"labels"``: the three outputs
                of ``ME.utils.sparse_collate``. ``d['x']`` is used both as the
                coordinates and as the features of every sample.
            ``"means"``: stacked ``d['mean']`` or None when no sample has it.
            ``"trans"``: dict with stacked ``"norm_factors"`` and ``"shifts"``,
                and ``"inv_rotations"`` (stacked, or None when absent).

    Raises:
        ValueError: if an optional key is present in only part of the batch.
    """
    point_sets = [d['x'] for d in list_data]
    coordinates, features, labels = ME.utils.sparse_collate(
        point_sets,
        list(point_sets),
        # one row per sample so the collated labels become (batch, ...)
        [d['y'].unsqueeze(0) for d in list_data],
        dtype=torch.float32
    )

    # collating other data
    norm_factors = torch.stack([d['norm_factor'] for d in list_data])
    shifts = torch.stack([d['shift'] for d in list_data])

    # Optional entries: which ones exist depends on the transform pipeline.
    # NOTE(review): 'mean' is presumably added by a transform that not every
    # pipeline includes - confirm against the transforms module.
    inv_rotations = _stack_optional(list_data, 'inverse_rotation')
    means = _stack_optional(list_data, 'mean')

    return {
        "coordinates"   : coordinates,
        "features"      : features,
        "labels"        : labels,
        "means"         : means,
        "trans": {"norm_factors"  : norm_factors,
                  "shifts"        : shifts,
                  "inv_rotations" : inv_rotations
                 }
        }

In [4]:
# Dataset and dataloader

# Transform pipeline applied to training samples (three RandomRotate steps,
# with second argument 0, 1 and 2, sit between Initialization and GaussianNoise).
train_transforms = [
    t.Translate(),
    t.SphereNormalization(),
    t.Initialization(),
    *(t.RandomRotate(180, i) for i in (0, 1, 2)),
    t.GaussianNoise(),
    t.GetMean(),
]

# Transform pipeline applied to validation samples.
valid_transforms = [t.Translate(), t.SphereNormalization()]

# Arguments shared by the train and validation variants.
_dataset_kwargs = dict(train=True, valid_split=0.2, category="cylinder")
_loader_kwargs = dict(batch_size=batch_size,
                      collate_fn=minkowski_collate,
                      num_workers=8)

# Training split: shuffled every epoch.
t_dataset = SHREC2022Primitives(path,
                                valid=False,
                                transform=train_transforms,
                                **_dataset_kwargs)
train_loader = DataLoader(t_dataset, shuffle=True, **_loader_kwargs)

# Validation split: fixed order.
v_dataset = SHREC2022Primitives(path,
                                valid=True,
                                transform=valid_transforms,
                                **_dataset_kwargs)
valid_loader = DataLoader(v_dataset, shuffle=False, **_loader_kwargs)

Specified split already exists. Using the existing one.
Specified split already exists. Using the existing one.


In [5]:
# network and device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

net = MinkowskiFCNN(in_channel = 3, out_channel = regression_targets).to(device)
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint written from a CUDA run still loads when we fall back to the CPU.
net.load_state_dict(torch.load(checkpoint_file, map_location=device))

Using device: cuda:0


<All keys matched successfully>

In [6]:
def create_input_batch(batch, device="cuda", quantization_size=0.05):
    """Build a ``ME.TensorField`` from a collated batch.

    Every coordinate column except the first is divided by
    ``quantization_size``; column 0 is left untouched (it holds the batch
    index written by ``ME.utils.sparse_collate``).

    The scaling is done on a copy, so ``batch["coordinates"]`` is not modified
    and the function can safely be called more than once on the same batch.

    Args:
        batch: dict with at least ``"coordinates"`` and ``"features"`` tensors.
        device: device on which the tensor field is created.
        quantization_size: positive divisor applied to the coordinates.

    Returns:
        ``ME.TensorField`` built from the scaled coordinates and the features.

    Raises:
        ValueError: if ``quantization_size`` is not strictly positive.
    """
    if quantization_size <= 0:
        raise ValueError(
            f"quantization_size must be positive, got {quantization_size}"
        )

    # Copy first: writing into batch["coordinates"] would rescale the caller's
    # tensor again on every subsequent call.
    coordinates = batch["coordinates"].clone()
    coordinates[:, 1:] = coordinates[:, 1:] / quantization_size

    return ME.TensorField(
        coordinates=coordinates,
        features=batch["features"],
        device=device
    )

In [7]:
# loss function and optimizer
# CylinderLoss yields three terms per call (axis, vertex, radius), which the
# training loop sums; Adam updates every parameter of `net` with step size `lr`.
criterion = CylinderLoss()
optimizer = torch.optim.Adam(params=net.parameters(), lr=lr)

In [8]:
# Training loop
#
# Runs `num_epochs` passes over `train_loader`, optimising the sum of the three
# loss terms returned by `criterion`, and prints per-epoch averages.

for epoch in range(num_epochs):
    # Put the network in training mode explicitly: in a notebook an earlier
    # cell may have left it in eval mode.
    net.train()

    # for loss and accuracy tracking the training set
    # (running sums over batches; divided by the number of batches below)
    m_loss = 0
    m_axis_loss = 0
    m_vertex_loss = 0
    m_radius_loss = 0
    
    for batch in tqdm(train_loader):

        optimizer.zero_grad()

        labels = batch["labels"].to(device)

        minknet_input = create_input_batch(
            batch, 
            device=device,
            quantization_size=0.05
        )

        pred = net(minknet_input)       
        
        # Drop the first label column; the remaining columns are the targets.
        # NOTE(review): column 0 is presumably a non-regression entry (e.g. a
        # class id) - confirm against the dataset.
        gt = labels[:,1:]

        # The criterion returns the axis, vertex and radius terms separately;
        # each is averaged over dimension 0 before being summed.
        a_loss, v_loss, r_loss = criterion(pred, gt, batch["trans"])
        a_loss = a_loss.mean(0)
        v_loss = v_loss.mean(0)
        r_loss = r_loss.mean(0)
        
        loss = a_loss + v_loss + r_loss
        
        loss.backward()
        optimizer.step()
        
        # .item() detaches the scalars so no autograd graph is kept alive
        m_axis_loss += a_loss.item()
        m_vertex_loss += v_loss.item()
        m_radius_loss += r_loss.item()
        m_loss += loss.item()

    # average over the number of batches in the epoch
    m_loss /= len(train_loader)
    m_axis_loss   /= len(train_loader)
    m_vertex_loss /= len(train_loader)
    m_radius_loss /= len(train_loader)
    
    print(f" Epoch: {epoch} | Training: loss = {m_loss}")
    print(f" Axis loss: {m_axis_loss} | Vertex loss: {m_vertex_loss} | Radius loss: {m_radius_loss}")
   

print("Done!")

100%|██████████| 58/58 [00:12<00:00,  4.49it/s]


 Epoch: 0 | Training: loss = 0.734213937973154
 Axis loss: 0.09141412236053369 | Vertex loss: 0.4385768565638312 | Radius loss: 0.20422295236895824


 41%|████▏     | 24/58 [00:07<00:10,  3.31it/s]


KeyboardInterrupt: 

In [None]:
# Reuse `checkpoint_path` from the configuration cell instead of repeating the
# hard-coded directory here, so the two locations cannot drift apart.
checkpoint_file = os.path.join(checkpoint_path, "cylinder_100.pth")
# WARNING: this is the same file the network weights were loaded from, so
# enabling the line below overwrites that checkpoint.
#torch.save(net.state_dict(), checkpoint_file)