In [None]:
#imports
import torch
from torch.utils.data import DataLoader
import MinkowskiEngine as ME
from tqdm import tqdm
import os
import time

import transforms as t
from dataset import SHREC2022Primitives
from networks import MinkowskiFCNN
from utils import minkowski_collate, create_input_batch 
import losses

# Configuring paths for data and checkpoints.
# The defaults are machine-specific absolute paths; export SHREC_DATA_PATH and/or
# SHREC_CHECKPOINT_PATH before starting the kernel to run on another machine
# without editing this cell.
data_path = os.environ.get(
    "SHREC_DATA_PATH",
    "/home/beastmaster/Desktop/Eleftheria/shrec/dataset/",
)
checkpoint_path = os.environ.get(
    "SHREC_CHECKPOINT_PATH",
    "/home/beastmaster/Desktop/vlassisgiannis/SHREC2022_PrimitiveRecognition-regressors/checkpoints",
)

# One checkpoint file per network trained in this notebook, all under `checkpoint_path`.
cls_checkpoint = os.path.join(checkpoint_path, "classification.pth")
plane_checkpoint = os.path.join(checkpoint_path, "plane.pth")
sphere_checkpoint = os.path.join(checkpoint_path, "sphere.pth")
cylinder_checkpoint = os.path.join(checkpoint_path, "cylinder.pth")
cone_checkpoint = os.path.join(checkpoint_path, "cone.pth")
torus_checkpoint = os.path.join(checkpoint_path, "torus.pth")

# General Parameters
valid_split = 0.15  # forwarded as `valid_split` to every SHREC2022Primitives dataset


# Transform pipelines
# (the same two lists are reused by every network and task in this notebook)
train_transforms = [
    t.KeepInitialPoints(),
    t.Translate(),
    t.SphereNormalization(),
    t.Initialization(),
    # three RandomRotate transforms, constructed with (180, 0), (180, 1), (180, 2);
    # presumably a maximum angle and an axis index -- confirm in transforms.py
    *[t.RandomRotate(180, axis) for axis in range(3)],
    t.GaussianNoise(),
    t.GetMean(),
]

# Validation applies no rotation or noise.
# NOTE(review): unlike the training list there is no t.Initialization() here --
# confirm that this is intended.
valid_transforms = [
    t.KeepInitialPoints(),
    t.Translate(),
    t.SphereNormalization(),
    t.GetMean(),
]


# Devices: `device` is CUDA when it is available and the CPU otherwise;
# `cpu_device` is where trained networks are moved once their checkpoint is saved.
cpu_device = torch.device("cpu")
device = torch.device("cuda") if torch.cuda.is_available() else cpu_device
print(f"Using Device: {device}")

# Classification network

In [None]:
batch_size = 256

# Classifier network: 3 input point features, 5 outputs (one per class).
cls_net = MinkowskiFCNN(in_channel=3, out_channel=5).to(device)

# Adam optimizer; StepLR scales the learning rate by `gamma` every `step_size` scheduler steps.
optimizer = torch.optim.Adam(cls_net.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.5)

criterion = torch.nn.CrossEntropyLoss()

# Arguments common to the training and validation variants built below.
split_kwargs = dict(train=True, valid_split=valid_split, category="all")
loader_kwargs = dict(batch_size=batch_size, collate_fn=minkowski_collate, num_workers=8)

# Training split: uses the augmenting transform list and shuffles the samples.
t_dataset = SHREC2022Primitives(data_path, valid=False, transform=train_transforms, **split_kwargs)
train_loader = DataLoader(t_dataset, shuffle=True, **loader_kwargs)

# Validation split: uses the validation transform list, fixed sample order.
v_dataset = SHREC2022Primitives(data_path, valid=True, transform=valid_transforms, **split_kwargs)
valid_loader = DataLoader(v_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Training of the classification network.
# Every epoch runs one pass over `train_loader`; every `eval_step` epochs a
# gradient-free pass over `valid_loader` reports validation loss and accuracy.
num_epochs = 200
eval_step = 20

start = time.time()

# Training Loop
for epoch in range(num_epochs):

    # running totals over the training set for this epoch
    m_loss = 0
    n_correct = 0
    n_samples = 0

    for batch in tqdm(train_loader):

        optimizer.zero_grad()

        # Getting data and formatting them: the class id is the first label column
        labels = batch["labels"][:, 0].long().to(device)
        minknet_input = create_input_batch(
            batch,
            device=device,
            quantization_size=0.05
        )

        # activating network
        pred = cls_net(minknet_input)

        # calculating loss and optimization
        loss = criterion(pred, labels)
        loss.backward()
        optimizer.step()

        # tracking progress; correct predictions and samples are counted separately
        # so that a smaller final batch is not over-weighted in the accuracy, and
        # the `batch_size` hyperparameter is no longer overwritten inside the loop
        m_loss += loss.item()
        n_correct += (torch.max(pred, dim=-1).indices == labels).sum().item()
        n_samples += labels.shape[0]

    # Dataset scores: mean of the batch losses, exact fraction of correct samples
    m_loss /= len(train_loader)
    acc = n_correct / n_samples
    print(f" Epoch: {epoch} | Training: loss = {m_loss} || accuracy = {acc*100}%")

    # step the scheduler
    scheduler.step()

    # evaluating progress
    if (epoch+1) % eval_step == 0:
        # resetting counters for the validation pass
        m_loss = 0
        n_correct = 0
        n_samples = 0

        # model to eval mode
        cls_net.eval()
        with torch.no_grad():

            for batch in tqdm(valid_loader):

                # Getting data and formatting them
                labels = batch["labels"][:, 0].long().to(device)
                net_in = create_input_batch(
                    batch,
                    device=device,
                    quantization_size=0.05
                )

                # Activating the network
                pred = cls_net(net_in)

                # Calculating loss and counting correct predictions
                loss = criterion(pred, labels)
                m_loss += loss.item()
                n_correct += (torch.max(pred, dim=-1).indices == labels).sum().item()
                n_samples += labels.shape[0]

            # scores over the validation dataset
            m_loss /= len(valid_loader)
            acc = n_correct / n_samples
            print(f" --------->  Evaluation: loss = {m_loss} || accuracy = {acc*100}%")

        # setting network back to training mode
        cls_net.train()


finish = time.time()

In [None]:
# Write the trained classifier weights to `cls_checkpoint`.
torch.save(obj=cls_net.state_dict(), f=cls_checkpoint)
# Training is done: move the network to the CPU to free memory on `device`.
cls_net = cls_net.to(cpu_device)
# Wall-clock duration of the training cell (`start` / `finish` are set there).
print(f"Time to train classification network: {finish - start} sec")

# Plane Regressor

In [None]:
batch_size = 128

# Plane regressor: 3 input point features, 3 outputs.
# IMPORTANT: only the plane normal is regressed; the plane point is taken as
# the average of the point cloud.
plane_net = MinkowskiFCNN(in_channel=3, out_channel=3).to(device)

# Adam optimizer; StepLR scales the learning rate by `gamma` every `step_size` scheduler steps.
optimizer = torch.optim.Adam(plane_net.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.5)

plane_loss = losses.PlaneLoss()


# Arguments common to the training and validation variants built below.
split_kwargs = dict(train=True, valid_split=valid_split, category="plane")
loader_kwargs = dict(batch_size=batch_size, collate_fn=minkowski_collate, num_workers=8)

# Training split: uses the augmenting transform list and shuffles the samples.
t_dataset = SHREC2022Primitives(data_path, valid=False, transform=train_transforms, **split_kwargs)
train_loader = DataLoader(t_dataset, shuffle=True, **loader_kwargs)

# Validation split: uses the validation transform list, fixed sample order.
v_dataset = SHREC2022Primitives(data_path, valid=True, transform=valid_transforms, **split_kwargs)
valid_loader = DataLoader(v_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Training of the plane regressor: one pass over `train_loader` per epoch and a
# gradient-free pass over `valid_loader` every `eval_step` epochs.
num_epochs = 100
eval_step = 10

start = time.time()

for epoch in range(num_epochs):
    # running sums of the per-batch mean losses over the training set
    m_loss = m_vertex_loss = m_normal_loss = 0

    for batch in tqdm(train_loader):

        optimizer.zero_grad()

        # ground truth: every label column after the first
        gt = batch["labels"].to(device)[:, 1:]
        minknet_input = create_input_batch(batch, device=device, quantization_size=0.05)

        # the network predicts the normal; the plane vertex is the batch's
        # precomputed point mean, appended to form a single prediction vector
        normal_pred = plane_net(minknet_input)
        pred = torch.cat([normal_pred, batch['means'].to(device)], dim=-1)

        normal_loss, vertex_loss = plane_loss(pred, gt, batch['trans'])

        # only the normal term is optimised; the vertex term is tracked but not
        # back-propagated
        loss = normal_loss.mean(0)
        loss.backward()
        optimizer.step()

        # tracking progress
        m_normal_loss += loss.item()
        m_vertex_loss += vertex_loss.mean(0).item()
        m_loss += loss.item()

    # stepping the scheduler
    scheduler.step()

    # epoch average scores
    n_batches = len(train_loader)
    m_loss /= n_batches
    m_normal_loss /= n_batches
    m_vertex_loss /= n_batches
    print(f" Epoch: {epoch} | Training: loss = {m_loss}")
    print(f" Normal loss: {m_normal_loss} | Vertex_loss: {m_vertex_loss}")

    # validation runs only on every `eval_step`-th epoch
    if (epoch + 1) % eval_step != 0:
        continue

    plane_net.eval()
    with torch.no_grad():
        m_loss = m_vertex_loss = m_normal_loss = 0

        for batch in tqdm(valid_loader):

            gt = batch["labels"].to(device)[:, 1:]
            minknet_input = create_input_batch(batch, device=device, quantization_size=0.05)

            # predicted normal followed by the point-cloud mean as the vertex
            normal_pred = plane_net(minknet_input)
            pred = torch.cat([normal_pred, batch['means'].to(device)], dim=-1)

            normal_loss, vertex_loss = plane_loss(pred, gt, batch['trans'])
            normal_value = normal_loss.mean(0).item()
            vertex_value = vertex_loss.mean(0).item()

            # NOTE(review): here the total is normal + vertex, whereas the training
            # total above is the normal term only -- the two printed "loss" values
            # are not directly comparable.
            m_normal_loss += normal_value
            m_vertex_loss += vertex_value
            m_loss += (normal_value + vertex_value)

        # validation average scores
        n_batches = len(valid_loader)
        m_loss /= n_batches
        m_normal_loss /= n_batches
        m_vertex_loss /= n_batches
        print(f" ----------- | Validation: loss = {m_loss}")
        print(f" Normal loss: {m_normal_loss} | Vertex_loss: {m_vertex_loss}")
    # back to training mode for the next epoch
    plane_net.train()

finish = time.time()


In [None]:
# Write the trained plane-regressor weights to `plane_checkpoint`.
torch.save(obj=plane_net.state_dict(), f=plane_checkpoint)
# Training is done: move the network to the CPU to free memory on `device`.
plane_net = plane_net.to(cpu_device)
# Wall-clock duration of the training cell (`start` / `finish` are set there).
print(f"Time to train plane regressor network: {finish - start} sec")

# Cylinder Regressor

In [None]:
batch_size = 128

# Cylinder regressor: 3 input point features, 7 outputs
# (consumed by losses.CylinderLoss as axis / vertex / radius terms).
cylinder_net = MinkowskiFCNN(in_channel=3, out_channel=7).to(device)

# Adam optimizer; StepLR scales the learning rate by `gamma` every `step_size` scheduler steps.
optimizer = torch.optim.Adam(cylinder_net.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.5)

cylinder_loss = losses.CylinderLoss()


# Arguments common to the training and validation variants built below.
split_kwargs = dict(train=True, valid_split=valid_split, category="cylinder")
loader_kwargs = dict(batch_size=batch_size, collate_fn=minkowski_collate, num_workers=8)

# Training split: uses the augmenting transform list and shuffles the samples.
t_dataset = SHREC2022Primitives(data_path, valid=False, transform=train_transforms, **split_kwargs)
train_loader = DataLoader(t_dataset, shuffle=True, **loader_kwargs)

# Validation split: uses the validation transform list, fixed sample order.
v_dataset = SHREC2022Primitives(data_path, valid=True, transform=valid_transforms, **split_kwargs)
valid_loader = DataLoader(v_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Training of the cylinder regressor: one pass over `train_loader` per epoch and a
# gradient-free pass over `valid_loader` every `eval_step` epochs.
num_epochs = 100
eval_step = 10

start = time.time()

for epoch in range(num_epochs):
    # running sums of the per-batch mean losses over the training set
    m_loss = m_axis_loss = m_vertex_loss = m_radius_loss = 0

    for batch in tqdm(train_loader):

        optimizer.zero_grad()

        # ground truth: every label column after the first
        gt = batch["labels"].to(device)[:, 1:]
        minknet_input = create_input_batch(batch, device=device, quantization_size=0.05)

        # activating network
        pred = cylinder_net(minknet_input)

        # axis / vertex / radius terms, each reduced with a mean along dim 0
        a_loss, v_loss, r_loss = (
            term.mean(0) for term in cylinder_loss(pred, gt, batch["trans"])
        )

        # all three terms are optimised with equal weight
        loss = a_loss + v_loss + r_loss

        loss.backward()
        optimizer.step()

        m_axis_loss += a_loss.item()
        m_vertex_loss += v_loss.item()
        m_radius_loss += r_loss.item()
        m_loss += loss.item()

    # stepping the scheduler
    scheduler.step()

    # epoch average scores
    n_batches = len(train_loader)
    m_loss /= n_batches
    m_axis_loss /= n_batches
    m_vertex_loss /= n_batches
    m_radius_loss /= n_batches

    print(f" Epoch: {epoch} | Training: loss = {m_loss}")
    print(f" Axis loss: {m_axis_loss} | Vertex loss: {m_vertex_loss} | Radius loss: {m_radius_loss}")

    # validation runs only on every `eval_step`-th epoch
    if (epoch + 1) % eval_step != 0:
        continue

    cylinder_net.eval()
    with torch.no_grad():
        # running sums of the per-batch mean losses over the validation set
        m_loss = m_axis_loss = m_vertex_loss = m_radius_loss = 0

        for batch in tqdm(valid_loader):

            gt = batch["labels"].to(device)[:, 1:]
            minknet_input = create_input_batch(batch, device=device, quantization_size=0.05)

            pred = cylinder_net(minknet_input)

            a_loss, v_loss, r_loss = (
                term.mean(0) for term in cylinder_loss(pred, gt, batch["trans"])
            )

            loss = a_loss + v_loss + r_loss

            m_axis_loss += a_loss.item()
            m_vertex_loss += v_loss.item()
            m_radius_loss += r_loss.item()
            m_loss += loss.item()

        # validation average scores
        n_batches = len(valid_loader)
        m_loss /= n_batches
        m_axis_loss /= n_batches
        m_vertex_loss /= n_batches
        m_radius_loss /= n_batches

        print(f" ----------- | Validation: loss = {m_loss}")
        print(f" Axis loss: {m_axis_loss} | Vertex loss: {m_vertex_loss} | Radius loss: {m_radius_loss}")
    # back to training mode for the next epoch
    cylinder_net.train()

finish = time.time()

In [None]:
# Write the trained cylinder-regressor weights to `cylinder_checkpoint`.
torch.save(obj=cylinder_net.state_dict(), f=cylinder_checkpoint)
# Training is done: move the network to the CPU to free memory on `device`.
cylinder_net = cylinder_net.to(cpu_device)
# Wall-clock duration of the training cell (`start` / `finish` are set there).
print(f"Time to train cylinder regressor network: {finish - start} sec")

# Sphere Regressor

In [None]:
batch_size = 128

# Sphere regressor: 3 input point features, 4 outputs
# (consumed by losses.SphereLoss as center / radius terms).
sphere_net = MinkowskiFCNN(in_channel=3, out_channel=4).to(device)

# Adam optimizer; StepLR scales the learning rate by `gamma` every `step_size` scheduler steps.
optimizer = torch.optim.Adam(sphere_net.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.5)

sphere_loss = losses.SphereLoss()


# Arguments common to the training and validation variants built below.
split_kwargs = dict(train=True, valid_split=valid_split, category="sphere")
loader_kwargs = dict(batch_size=batch_size, collate_fn=minkowski_collate, num_workers=8)

# Training split: uses the augmenting transform list and shuffles the samples.
t_dataset = SHREC2022Primitives(data_path, valid=False, transform=train_transforms, **split_kwargs)
train_loader = DataLoader(t_dataset, shuffle=True, **loader_kwargs)

# Validation split: uses the validation transform list, fixed sample order.
v_dataset = SHREC2022Primitives(data_path, valid=True, transform=valid_transforms, **split_kwargs)
valid_loader = DataLoader(v_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Training of the sphere regressor: one pass over `train_loader` per epoch and a
# gradient-free pass over `valid_loader` every `eval_step` epochs.
num_epochs = 100
eval_step = 10

start = time.time()

for epoch in range(num_epochs):
    # running sums of the per-batch mean losses over the training set
    m_loss = 0
    m_center_loss = 0
    m_radius_loss = 0

    for batch in tqdm(train_loader):

        optimizer.zero_grad()

        # reading the data and formatting them;
        # ground truth is every label column after the first
        labels = batch["labels"].to(device)
        gt = labels[:,1:]
        minknet_input = create_input_batch(
            batch,
            device=device,
            quantization_size=0.05
        )

        # activating network
        pred = sphere_net(minknet_input)

        # calculating losses (center and radius terms, each averaged along dim 0)
        c_loss, r_loss = sphere_loss(pred, gt, batch["trans"])
        c_loss = c_loss.mean(0)
        r_loss = r_loss.mean(0)

        loss = r_loss + c_loss

        loss.backward()
        optimizer.step()

        m_center_loss += c_loss.item()
        m_radius_loss += r_loss.item()
        m_loss += loss.item()

    # stepping the scheduler
    scheduler.step()

    # epoch average scores
    m_loss /= len(train_loader)
    m_center_loss /= len(train_loader)
    m_radius_loss /= len(train_loader)

    print(f" Epoch: {epoch} | Training: loss = {m_loss}")
    print(f" Center loss: {m_center_loss} | Radius loss: {m_radius_loss}")


    # validation
    if (epoch+1)%eval_step == 0:
        sphere_net.eval()
        with torch.no_grad():
            # running sums of the per-batch mean losses over the validation set
            m_loss = 0
            m_center_loss = 0
            m_radius_loss = 0

            # iterate the validation loader: the sums below are divided by
            # len(valid_loader), so the loop must run over the same loader
            for batch in tqdm(valid_loader):

                # reading the data and formatting them
                labels = batch["labels"].to(device)
                gt = labels[:,1:]
                minknet_input = create_input_batch(
                    batch,
                    device=device,
                    quantization_size=0.05
                )

                # activating network
                pred = sphere_net(minknet_input)

                # calculating losses
                c_loss, r_loss = sphere_loss(pred, gt, batch["trans"])
                c_loss = c_loss.mean(0)
                r_loss = r_loss.mean(0)

                loss = r_loss + c_loss

                m_center_loss += c_loss.item()
                m_radius_loss += r_loss.item()
                m_loss += loss.item()

            # validation average scores
            m_loss        /= len(valid_loader)
            m_center_loss /= len(valid_loader)
            m_radius_loss /= len(valid_loader)

            print(f" ------------ | Validation: loss = {m_loss}")
            print(f" Center loss: {m_center_loss} | Radius loss: {m_radius_loss}")


        # back to training mode for the next epoch
        sphere_net.train()

finish = time.time()

In [None]:
# Write the trained sphere-regressor weights to `sphere_checkpoint`.
torch.save(obj=sphere_net.state_dict(), f=sphere_checkpoint)
# Training is done: move the network to the CPU to free memory on `device`.
sphere_net = sphere_net.to(cpu_device)
# Wall-clock duration of the training cell (`start` / `finish` are set there).
print(f"Time to train sphere regressor network: {finish - start} sec")

# Cone Regressor

In [None]:
batch_size = 128

# Cone regressor: 3 input point features, 7 outputs
# (consumed by losses.ConeLoss as axis / vertex / theta terms).
cone_net = MinkowskiFCNN(in_channel=3, out_channel=7).to(device)

# Adam optimizer; StepLR scales the learning rate by `gamma` every `step_size` scheduler steps.
optimizer = torch.optim.Adam(cone_net.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.5)

cone_loss = losses.ConeLoss()


# Arguments common to the training and validation variants built below.
split_kwargs = dict(train=True, valid_split=valid_split, category="cone")
loader_kwargs = dict(batch_size=batch_size, collate_fn=minkowski_collate, num_workers=8)

# Training split: uses the augmenting transform list and shuffles the samples.
t_dataset = SHREC2022Primitives(data_path, valid=False, transform=train_transforms, **split_kwargs)
train_loader = DataLoader(t_dataset, shuffle=True, **loader_kwargs)

# Validation split: uses the validation transform list, fixed sample order.
v_dataset = SHREC2022Primitives(data_path, valid=True, transform=valid_transforms, **split_kwargs)
valid_loader = DataLoader(v_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Training of the cone regressor: one pass over `train_loader` per epoch and a
# gradient-free pass over `valid_loader` every `eval_step` epochs.
num_epochs = 100
eval_step = 10

start = time.time()

for epoch in range(num_epochs):
    # running sums of the per-batch mean losses over the training set
    m_loss = 0
    m_axis_loss = 0
    m_vertex_loss = 0
    m_theta_loss = 0

    for batch in tqdm(train_loader):

        optimizer.zero_grad()

        # reading the data and formatting them;
        # ground truth is every label column after the first
        labels = batch["labels"].to(device)
        gt = labels[:,1:]
        minknet_input = create_input_batch(
            batch,
            device=device,
            quantization_size=0.05
        )

        # activating network
        pred = cone_net(minknet_input)

        # calculating losses (axis, vertex and theta terms, each averaged along dim 0)
        a_loss, v_loss, t_loss = cone_loss(pred, gt, batch["trans"])
        a_loss = a_loss.mean(0)
        v_loss = v_loss.mean(0)
        t_loss = t_loss.mean(0)

        loss = a_loss + v_loss + t_loss

        loss.backward()
        optimizer.step()

        m_axis_loss += a_loss.item()
        m_vertex_loss += v_loss.item()
        m_theta_loss += t_loss.item()
        m_loss += loss.item()

    # stepping the scheduler
    scheduler.step()

    # epoch average scores
    m_loss /= len(train_loader)
    m_axis_loss /= len(train_loader)
    m_vertex_loss /= len(train_loader)
    m_theta_loss /= len(train_loader)

    print(f" Epoch: {epoch} | Training: loss = {m_loss}")
    print(f" Axis loss: {m_axis_loss} | Vertex loss: {m_vertex_loss} | Theta loss: {m_theta_loss}")



    # validation
    if (epoch+1)%eval_step == 0:
        cone_net.eval()
        with torch.no_grad():
            # running sums of the per-batch mean losses over the validation set
            m_loss = 0
            m_axis_loss = 0
            m_vertex_loss = 0
            m_theta_loss = 0

            # iterate the validation loader: the sums below are divided by
            # len(valid_loader), so the loop must run over the same loader
            for batch in tqdm(valid_loader):

                # reading the data and formatting them
                labels = batch["labels"].to(device)
                gt = labels[:,1:]
                minknet_input = create_input_batch(
                    batch,
                    device=device,
                    quantization_size=0.05
                )

                # activating network
                pred = cone_net(minknet_input)

                # calculating losses
                a_loss, v_loss, t_loss = cone_loss(pred, gt, batch["trans"])
                a_loss = a_loss.mean(0)
                v_loss = v_loss.mean(0)
                t_loss = t_loss.mean(0)

                loss = a_loss + v_loss + t_loss

                m_axis_loss += a_loss.item()
                m_vertex_loss += v_loss.item()
                m_theta_loss += t_loss.item()
                m_loss += loss.item()


            # validation average scores
            m_loss        /= len(valid_loader)
            m_axis_loss   /= len(valid_loader)
            m_vertex_loss /= len(valid_loader)
            m_theta_loss  /= len(valid_loader)

            print(f" ---------- | Validation: loss = {m_loss}")
            print(f" Axis loss: {m_axis_loss} | Vertex loss: {m_vertex_loss} | Theta loss: {m_theta_loss}")

        # back to training mode for the next epoch
        cone_net.train()

finish = time.time()

In [None]:
# Write the trained cone-regressor weights to `cone_checkpoint`.
torch.save(obj=cone_net.state_dict(), f=cone_checkpoint)
# Training is done: move the network to the CPU to free memory on `device`.
cone_net = cone_net.to(cpu_device)
# Wall-clock duration of the training cell (`start` / `finish` are set there).
print(f"Time to train cone regressor network: {finish - start} sec")

# Torus Regressor

In [None]:
batch_size = 128

# Torus regressor: 3 input point features, 8 outputs
# (consumed by losses.TorusLoss as axis / center / two radius terms).
torus_net = MinkowskiFCNN(in_channel=3, out_channel=8).to(device)

# Adam optimizer; StepLR scales the learning rate by `gamma` every `step_size` scheduler steps.
optimizer = torch.optim.Adam(torus_net.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.5)

torus_loss = losses.TorusLoss()


# Arguments common to the training and validation variants built below.
split_kwargs = dict(train=True, valid_split=valid_split, category="torus")
loader_kwargs = dict(batch_size=batch_size, collate_fn=minkowski_collate, num_workers=8)

# Training split: uses the augmenting transform list and shuffles the samples.
t_dataset = SHREC2022Primitives(data_path, valid=False, transform=train_transforms, **split_kwargs)
train_loader = DataLoader(t_dataset, shuffle=True, **loader_kwargs)

# Validation split: uses the validation transform list, fixed sample order.
v_dataset = SHREC2022Primitives(data_path, valid=True, transform=valid_transforms, **split_kwargs)
valid_loader = DataLoader(v_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Training of the torus regressor: one pass over `train_loader` per epoch and a
# gradient-free pass over `valid_loader` every `eval_step` epochs.
num_epochs = 100
eval_step = 10

start = time.time()

for epoch in range(num_epochs):
    # running sums of the per-batch mean losses over the training set
    # (`m_Radius_loss` tracks the R term, `m_radius_loss` the r term)
    m_loss = 0
    m_center_loss = 0
    m_axis_loss   = 0
    m_radius_loss = 0
    m_Radius_loss = 0

    for batch in tqdm(train_loader):

        optimizer.zero_grad()

        # reading the data and formatting them;
        # ground truth is every label column after the first
        labels = batch["labels"].to(device)
        gt = labels[:,1:]
        minknet_input = create_input_batch(
            batch,
            device=device,
            quantization_size=0.05
        )

        # activating network
        pred = torus_net(minknet_input)

        # calculating losses (axis, center, R and r terms, each averaged along dim 0)
        a_loss, c_loss, R_loss, r_loss = torus_loss(pred, gt, batch["trans"])
        c_loss = c_loss.mean(0)
        r_loss = r_loss.mean(0)
        R_loss = R_loss.mean(0)
        a_loss = a_loss.mean(0)

        loss = R_loss + r_loss + c_loss + a_loss

        loss.backward()
        optimizer.step()

        m_center_loss += c_loss.item()
        m_Radius_loss += R_loss.item()
        m_radius_loss += r_loss.item()
        m_axis_loss   += a_loss.item()
        m_loss += loss.item()

    # stepping the scheduler
    scheduler.step()

    # epoch average scores
    m_loss /= len(train_loader)
    m_center_loss /= len(train_loader)
    m_axis_loss   /= len(train_loader)
    m_radius_loss /= len(train_loader)
    m_Radius_loss /= len(train_loader)

    print(f" Epoch: {epoch} | Training: loss = {m_loss}")
    print(f" Center loss: {m_center_loss} | Axis loss: {m_axis_loss} | Radius loss: {m_Radius_loss} | radius loss: {m_radius_loss}")



    # validation
    if (epoch+1)%eval_step == 0:
        torus_net.eval()
        with torch.no_grad():
            # running sums of the per-batch mean losses over the validation set
            m_loss = 0
            m_center_loss = 0
            m_axis_loss   = 0
            m_radius_loss = 0
            m_Radius_loss = 0

            # iterate the validation loader: the sums below are divided by
            # len(valid_loader), so the loop must run over the same loader
            for batch in tqdm(valid_loader):

                # reading the data and formatting them
                labels = batch["labels"].to(device)
                gt = labels[:,1:]
                minknet_input = create_input_batch(
                    batch,
                    device=device,
                    quantization_size=0.05
                )

                # activating network
                pred = torus_net(minknet_input)

                # calculating losses
                a_loss, c_loss, R_loss, r_loss = torus_loss(pred, gt, batch["trans"])
                c_loss = c_loss.mean(0)
                r_loss = r_loss.mean(0)
                R_loss = R_loss.mean(0)
                a_loss = a_loss.mean(0)

                loss = R_loss + r_loss + c_loss + a_loss


                m_center_loss += c_loss.item()
                m_Radius_loss += R_loss.item()
                m_radius_loss += r_loss.item()
                m_axis_loss   += a_loss.item()
                m_loss += loss.item()

            # validation average scores
            m_loss        /= len(valid_loader)
            m_center_loss /= len(valid_loader)
            m_axis_loss   /= len(valid_loader)
            m_radius_loss /= len(valid_loader)
            m_Radius_loss /= len(valid_loader)

            print(f" ---------- | Validation: loss = {m_loss}")
            print(f" Center loss: {m_center_loss} | Axis loss: {m_axis_loss} | Radius loss: {m_Radius_loss} | radius loss: {m_radius_loss}")

        # back to training mode for the next epoch
        torus_net.train()

finish = time.time()

In [None]:
# Write the trained torus-regressor weights to `torus_checkpoint`.
torch.save(obj=torus_net.state_dict(), f=torus_checkpoint)
# Training is done: move the network to the CPU to free memory on `device`.
torus_net = torus_net.to(cpu_device)
# Wall-clock duration of the training cell (`start` / `finish` are set there).
print(f"Time to train torus regressor network: {finish - start} sec")