## Creating a classifier and parameter regressor for all primitive classes (plane, cylinder, cone, sphere, torus)

In [1]:
# imports
from dataset import SHREC2022Primitives
from networks import MinkowskiMR
import transforms as t
import losses as L
import MinkowskiEngine as ME


import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import os

In [2]:
# configuration & hyperparameters
# Dataset root. Defaults to the location used when this notebook was written,
# but can be overridden through the SHREC2022_DATASET_PATH environment
# variable so the notebook runs on other machines without editing this cell.
path = os.environ.get(
    "SHREC2022_DATASET_PATH",
    "/home/ioannis/Desktop/programming/data/SHREC/SHREC2022/dataset",
)
batch_size = 128  # samples per mini-batch, used by both data loaders
lr = 1e-3         # learning rate handed to the optimizer
num_epochs = 20   # number of passes over the training set
# Evaluation interval in epochs. The value is far larger than num_epochs, so
# an `(epoch + 1) % eval_step == 0` check never fires during a run.
eval_step = 1_000_000_000

In [3]:
def minkowski_collate(list_data):
    """Collate a list of dataset samples into a single batch dictionary.

    Every sample is a dict read through the keys 'x', 'y', 'norm_factor',
    'shift' and, optionally, 'inverse_rotation'.

    Returns:
        dict with
          "coordinates", "features", "labels": the three outputs of
              ME.utils.sparse_collate. sample['x'] is supplied both as the
              coordinates and as the features, and every sample['y'] gets a
              leading axis (unsqueeze(0)) before collation.
          "trans": dict holding the per-sample "norm_factors" and "shifts"
              stacked along a new leading axis, and "inv_rotations", which is
              the stacked 'inverse_rotation' entries or None when no sample
              carries one.
    """
    points = [sample['x'] for sample in list_data]
    targets = [sample['y'].unsqueeze(0) for sample in list_data]

    # The point coordinates double as the per-point features.
    coordinates, features, labels = ME.utils.sparse_collate(
        points,
        points,
        targets,
        dtype=torch.float32
    )

    # Gather the bookkeeping of the transforms applied to each sample.
    scale_list = [sample['norm_factor'] for sample in list_data]
    shift_list = [sample['shift'] for sample in list_data]
    rotation_list = [
        sample['inverse_rotation']
        for sample in list_data
        if 'inverse_rotation' in sample
    ]

    transform_info = {
        "norm_factors"  : torch.stack(scale_list),
        "shifts"        : torch.stack(shift_list),
        # Rotations are optional: only stack them when at least one exists.
        "inv_rotations" : torch.stack(rotation_list) if rotation_list else None,
    }

    return {
        "coordinates"   : coordinates,
        "features"      : features,
        "labels"        : labels,
        "trans"         : transform_info,
    }

In [4]:
# Dataset and dataloader

# Training-time transform pipeline. The three RandomRotate entries differ only
# in their second argument (0, 1, 2) -- presumably the rotation axis, with 180
# as the angle range; confirm against the transforms module.
train_transforms = [
    t.Initialization(),
    t.Translate(),
    t.SphereNormalization(),
    *[t.RandomRotate(180, axis) for axis in range(3)],
]  # t.GaussianNoise() is currently left out

# Validation-time pipeline: no random rotations.
# NOTE(review): unlike the training pipeline this list has no
# t.Initialization() step -- confirm that this is intentional.
valid_transforms = [
    t.Translate(),
    t.SphereNormalization(),
]

# Arguments shared by the training and validation variants.
split_kwargs = dict(train=True, valid_split=0.2, category="all")
loader_kwargs = dict(
    batch_size=batch_size,
    collate_fn=minkowski_collate,
    num_workers=8,
)

# Training portion of the split, reshuffled by the loader.
t_dataset = SHREC2022Primitives(
    path,
    valid=False,
    transform=train_transforms,
    **split_kwargs,
)
train_loader = DataLoader(t_dataset, shuffle=True, **loader_kwargs)

# Validation portion of the same split, iterated in a fixed order.
v_dataset = SHREC2022Primitives(
    path,
    valid=True,
    transform=valid_transforms,
    **split_kwargs,
)
valid_loader = DataLoader(v_dataset, shuffle=False, **loader_kwargs)


Specified split already exists. Using the existing one.
Specified split already exists. Using the existing one.


In [5]:
# network and device
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Build the network (in_channel=3, out_channel=6) and place it on `device`.
net = MinkowskiMR(in_channel=3, out_channel=6).to(device)

Using device: cuda:0


In [6]:
def create_input_batch(batch, device="cuda", quantization_size=0.05):
    """Build the network input (an ME.TensorField) from a collated batch.

    Args:
        batch: dict with "coordinates" and "features" tensors. Column 0 of
            "coordinates" is passed through unchanged (presumably the batch
            index written by the collate step -- confirm); every other column
            is divided by ``quantization_size``.
        device: device handed to ME.TensorField.
        quantization_size: divisor applied to the coordinate columns.

    Returns:
        ME.TensorField built from the rescaled coordinates and the batch's
        features.

    The caller's batch is left untouched. Rescaling used to happen in place on
    batch["coordinates"], so calling this function twice on the same batch
    divided the coordinates twice.
    """
    # Work on a copy so the caller keeps the unscaled coordinates.
    coordinates = batch["coordinates"].clone()
    coordinates[:, 1:] = coordinates[:, 1:] / quantization_size
    return ME.TensorField(
        coordinates=coordinates,
        features=batch["features"],
        device=device
    )

In [7]:
# optimizer and losses
# Adam over all network parameters with the configured learning rate.
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

# Cross-entropy for the classification output.
cls_criterion = torch.nn.CrossEntropyLoss()

# One loss object per primitive type. The list order (plane, cylinder, cone,
# sphere, torus) is the order in which the training loop stacks the
# per-primitive losses before selecting one column per sample by class index.
criteria = [
    L.PlaneLoss(),
    L.CylinderLoss(),
    L.ConeLoss(),
    L.SphereLoss(),
    L.TorusLoss(),
]

# Named handles to the same loss objects.
(plane_criterion,
 cylinder_criterion,
 cone_criterion,
 sphere_criterion,
 torus_criterion) = criteria

In [8]:
# Training loop
#
# NOTE(review): the saved output of this cell ends in a broadcasting
# RuntimeError (size 3 vs. 128 at dimension 2) raised on the first batch. The
# shapes involved are not visible in this cell; the per-primitive loss calls,
# which combine the predictions with batch["trans"], are the likely place to
# look -- confirm before relying on this loop.

for epoch in range(num_epochs):
    # running sum of the per-batch training loss for this epoch
    m_loss = 0

    for batch in tqdm(train_loader):

        optimizer.zero_grad()

        # Column 0 of the labels is the primitive class index; the remaining
        # columns are the ground-truth values handed to the regression losses.
        labels = batch["labels"].to(device)
        classes = labels[:, 0].long()
        gt = labels[:, 1:]

        minknet_input = create_input_batch(
            batch,
            device=device,
            quantization_size=0.05
        )

        # Classification output followed by one output per primitive type.
        cls, plane, cyl, cone, sphere, torus = net(minknet_input)
        predictions = (plane, cyl, cone, sphere, torus)

        cls_loss = cls_criterion(cls, classes)

        # B x 5 tensor: column k holds the loss of the k-th primitive type for
        # every sample. `criteria` is ordered plane, cylinder, cone, sphere,
        # torus, matching `predictions`.
        losses = torch.stack(
            [
                criterion(prediction, gt, batch["trans"])
                for criterion, prediction in zip(criteria, predictions)
            ],
            dim=1,
        )

        # For each sample keep only the loss of its ground-truth class
        # (B x 1 after the gather), average it and add the classification loss.
        loss = torch.gather(losses, 1, classes.unsqueeze(-1)).mean() + cls_loss

        loss.backward()
        optimizer.step()

        m_loss += loss.item()

    # mean loss over the batches of this epoch
    m_loss /= len(train_loader)

    print(f" Epoch: {epoch} | Training: loss = {m_loss}")

    # TODO: periodic evaluation on valid_loader every `eval_step` epochs.
    # The earlier commented-out draft was removed: it called names that are not
    # defined in the visible cells (transform_outputs, criterion) and read
    # batch["norm_factors"] / batch["shifts"] / batch["inv_rotations"], which
    # the collate function stores under batch["trans"].

print("Done!")

  0%|          | 0/288 [00:05<?, ?it/s]


RuntimeError: The size of tensor a (3) must match the size of tensor b (128) at non-singleton dimension 2