In [None]:
# Automatically reload imported modules before each cell runs, so edits to
# project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [29]:
from colorcloud.cheng2023TransRVNet import TransVRNet
import lightning as L
import numpy as np
import wandb
from lightning.pytorch.loggers import WandbLogger
from datetime import datetime
import torch
from lovasz_losses import lovasz_softmax
from boundary_loss import BoundaryLoss
from colorcloud.behley2019iccv import SemanticKITTIDataset, SphericalProjection, ProjectionTransform, ProjectionVizTransform
from torchvision.transforms import v2

## Data

In [11]:
# SemanticKITTI data root, relative to this notebook.
data_path = '../../data'
ds = SemanticKITTIDataset(data_path)

# Read one raw sample to check that the dataset loads.
sample_idx = 128
frame, label, mask = ds[sample_idx]

In [33]:
# Spherical projection to a 64 x 1024 image, vertical FOV from +12 to -26 degrees.
proj = SphericalProjection(fov_up_deg=12., fov_down_deg=-26., W=1024, H=64)

# Project each scan, then apply the visualisation transform (presumably maps
# labels to RGB via the dataset's colour map -- confirm in behley2019iccv).
projection_step = ProjectionTransform(proj)
viz_step = ProjectionVizTransform(ds.color_map_rgb_np, ds.learning_map_inv_np)
tfms = v2.Compose([projection_step, viz_step])
ds.set_transform(tfms)

# Re-read the same sample, now going through the transform pipeline.
img, label, _ = ds[128]

## Setup

### Loss

In [37]:
# Loss references:
# - lovasz_softmax(probas, labels) takes per-class probabilities (not logits): https://github.com/bermanmaxim/LovaszSoftmax/tree/master
# - BoundaryLoss: https://github.com/yiskw713/boundary_loss_for_remote_sensing

In [38]:
def calculate_class_weights(frequencies, exponent):
    """Compute per-class loss weights from class frequencies.

    Uses ``w_c = (f_t / f_c) ** exponent``, where ``f_c`` is the frequency of
    class ``c`` and ``f_t`` is the median of all class frequencies.

    Args:
        frequencies: 1-D array or sequence of per-class frequencies. Plain
            Python lists are accepted.
        exponent: the exponent ``i`` applied to the frequency ratio.

    Returns:
        A ``torch.float32`` tensor with one weight per class. Classes with zero
        frequency get weight 0 rather than an infinite weight.
    """
    freqs = np.asarray(frequencies, dtype=np.float64)
    median_freq = np.median(freqs)

    # Only classes that actually occur get a ratio; absent classes keep 0 so
    # we never divide by zero (which would give inf, or 1 when exponent == 0).
    weights = np.zeros_like(freqs)
    present = freqs > 0
    weights[present] = (median_freq / freqs[present]) ** exponent
    return torch.tensor(weights, dtype=torch.float32)

# Per-class frequencies, one entry per class (20 entries).
# TODO: placeholder values -- replace with frequencies counted from the training split.
CLASS_FREQUENCIES = np.array([
    100, 200, 300, 50, 25,
    60, 70, 80, 90, 120,
    130, 140, 150, 160, 170,
    180, 190, 200, 210, 220,
])
# Exponent ``i`` in w_c = (f_t / f_c) ** i; 0.5 takes the square root of the ratio.
EXPONENT_I = 0.5

In [35]:
# Per-class weights from the configured frequencies; with a positive exponent,
# rarer classes receive larger weights in the cross-entropy criterion below.
class_weights = calculate_class_weights(frequencies=CLASS_FREQUENCIES, exponent=EXPONENT_I)
loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)

In [36]:
# One boundary-loss criterion, created once and reused on every call.
# NOTE(review): default constructor arguments assumed; confirm against the
# vendored boundary_loss module.
_boundary_criterion = BoundaryLoss()


def total_loss(output, target, w_wce=1.0, w_ls=3.0, w_bd=1.0):
    """Combined segmentation loss: weighted CE + Lovasz-Softmax + boundary loss.

    Args:
        output: raw model logits -- assumed (N, C, H, W); TODO confirm.
        target: integer class labels -- assumed (N, H, W); TODO confirm.
        w_wce: weight of the weighted cross-entropy term (default 1.0).
        w_ls: weight of the Lovasz-Softmax term (default 3.0).
        w_bd: weight of the boundary-loss term (default 1.0).

    Returns:
        Scalar tensor ``w_wce * CE + w_ls * Lovasz + w_bd * Boundary``.
    """
    # `loss_fn` stores its class weights as a buffer; keep it on the same
    # device as the logits so CUDA/MPS training does not hit a device mismatch.
    weighted_cross_entropy_loss = loss_fn.to(output.device)(output, target)
    # lovasz_softmax takes per-class probabilities, not raw logits.
    lovasz_loss = lovasz_softmax(torch.softmax(output, dim=1), target)
    # NOTE(review): the reference BoundaryLoss is documented to take pre-softmax
    # output (it applies softmax itself) -- verify for the vendored copy.
    boundary_term = _boundary_criterion(output, target)

    # The three terms are summed (the previous version multiplied CE by Lovasz).
    return w_wce * weighted_cross_entropy_loss + w_ls * lovasz_loss + w_bd * boundary_term

### Convolution Parameters

In [None]:
# Channel configuration for the first TransVRNet input branch.
# NOTE(review): key meanings (in/out channels of blocks b1/b2/b3) are inferred
# from the names only -- confirm against colorcloud.cheng2023TransRVNet.
p1 = dict(
    b1_in=1,
    b1_out1=1,
    b1_out2=1,
    b1_out3=1,
    b2_in=32,
    b2_out=32,
    b3_in=64,
    b3_out1=64,
    b3_out2=64,
    b3_out3=64,
    output=128,
)

# Second configuration: identical to p1 except that b1_in is 3.
p2 = {**p1, "b1_in": 3}

## Training

In [None]:
# Prefer CUDA, then Apple MPS, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# Build the network from the two channel configurations and move it to `device`.
model = TransVRNet(p1, p2)
model = model.to(device)
print(model)

In [None]:
# Regularisation and augmentation used by the training loop.
dropout = torch.nn.Dropout(p=0.2, inplace=False)
# Only `torchvision.transforms.v2` is imported in this notebook (the bare
# `torchvision` name is not), so use the v2 equivalents.
inverter = v2.RandomInvert(p=0.5)
rotation = v2.RandomRotation(degrees=(0, 180))

NUM_EPOCHS = 30

# OneCycleLR overwrites each param group's lr with max_lr / div_factor when it
# is constructed, so the optimizer's own lr is never used. It is set to that
# effective starting value (0.002) to avoid a misleading number.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.002, weight_decay=0.0001)

# TODO: `data_loader` is not defined in the visible notebook; build a
# torch.utils.data.DataLoader over the training split before running this cell.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.002,
    div_factor=1,
    final_div_factor=10,
    steps_per_epoch=len(data_loader),
    epochs=NUM_EPOCHS,
)

In [None]:
# Must match the epochs given to OneCycleLR (30); stepping the scheduler past
# its planned total number of steps raises an error.
epochs = 30
# Print the loss every `log_interval` batches (not defined elsewhere in the
# visible notebook).
log_interval = 10

# CrossEntropyLoss keeps its class weights as a buffer; move it with the model.
loss_fn.to(device)

for epoch in range(epochs):
    model.train()

    for batch_idx, (data, target) in enumerate(data_loader):
        optimizer.zero_grad()

        # Photometric augmentation: affects the input only.
        data = inverter(data)

        # Geometric augmentation must move the labels together with the input,
        # otherwise pixels and labels no longer line up. Draw one angle from the
        # range configured on `rotation` and rotate both tensors with it
        # (nearest-neighbour by default, so label ids stay integral; uncovered
        # pixels are filled with 0 -- NOTE(review): confirm 0 is the ignore/unlabeled id).
        low, high = rotation.degrees
        angle = float(torch.empty(()).uniform_(float(low), float(high)))
        data = v2.functional.rotate(data, angle)
        target = v2.functional.rotate(target, angle)

        # The model lives on `device`; the batch has to be there too.
        data, target = data.to(device), target.to(device)

        # Forward pass
        output = model(data)

        # NOTE(review): dropout is applied to the final model output here, outside
        # the model, so model.eval() will not disable it -- confirm this is intended.
        output = dropout(output)

        loss = total_loss(output, target)
        loss.backward()
        optimizer.step()
        # OneCycleLR is stepped once per batch.
        scheduler.step()

        if batch_idx % log_interval == 0:
            print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item()}")

torch.save(model.state_dict(), "trained_transvrnet_model.pth")