In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from colorcloud.cheng2023TransRVNet import TransVRNet
from colorcloud.cheng2023TransRVNet import TransRVNet_loss
import lightning as L
import numpy as np
import wandb
from lightning.pytorch.loggers import WandbLogger
from datetime import datetime
import torch
import torchvision
from lovasz_losses import lovasz_softmax
from boundary_loss import BoundaryLoss
from colorcloud.behley2019iccv import SemanticKITTIDataset, SphericalProjection, ProjectionTransform, ProjectionVizTransform
from torchvision.transforms import v2
from torch.utils.data import DataLoader

## Data

In [3]:
# Root of the SemanticKITTI odometry data. Set the KITTI_DATA_PATH environment
# variable to point elsewhere instead of editing this machine-specific path.
import os

data_path = os.environ.get(
    "KITTI_DATA_PATH",
    '/mnt/c/Users/mathe/OneDrive/Área de Trabalho/FACULDADE/Pesquisa/odometry',
)
ds = SemanticKITTIDataset(data_path)
frame, label, mask = ds[128]

In [4]:
# Spherical range-image projection: H=64 rows x W=1024 columns, with the vertical
# field of view running from +12 deg (up) to -26 deg (down).
proj = SphericalProjection(H=64, W=1024, fov_up_deg=12.0, fov_down_deg=-26.0)

# Project every scan, then apply the visualization transform built from the
# dataset's color map and inverse learning map.
projection_steps = [
    ProjectionTransform(proj),
    ProjectionVizTransform(ds.color_map_rgb_np, ds.learning_map_inv_np),
]
tfms = v2.Compose(projection_steps)
ds.set_transform(tfms)

# Pull one transformed sample to check the pipeline end to end.
img, label, _ = ds[128]

bs = 1
data_loader = DataLoader(ds, batch_size=bs)

## Setup

### Loss

In [5]:
def calculate_frequencies(dataset, num_classes=20, min_label=-1):
    """Count how often each label id occurs across all samples of ``dataset``.

    Args:
        dataset: Iterable of ``(img, labels, extra)`` triples; only ``labels`` is
            read. ``labels`` may be anything ``np.asarray`` accepts (numpy
            array, CPU torch tensor, nested lists) holding label ids.
        num_classes: Exclusive upper bound of the counted label ids.
        min_label: Smallest counted label id. The default ``-1`` adds one slot
            in front of classes ``0 .. num_classes - 1`` (presumably for
            unlabeled / invalid pixels -- confirm against the dataset).

    Returns:
        List of ``num_classes - min_label`` counts where index ``i`` holds the
        count for label ``min_label + i``. With the defaults this has 21
        entries and class ``c`` sits at index ``c + 1``; drop index 0 before
        using it to build one weight per logit channel (e.g. the 20-channel
        logits in the loss example). Zero counts are reported as 1 so that
        inverse-frequency weights stay finite.

    Raises:
        KeyError: if a label outside ``[min_label, num_classes)`` is found.
    """
    class_frequencies = dict.fromkeys(range(min_label, num_classes), 0)

    for _img, labels, _extra in dataset:
        # np.unique on the flattened labels gives each present id and its count.
        ids, counts = np.unique(np.asarray(labels).ravel(), return_counts=True)
        # .tolist() yields plain Python numbers, so the result is not a mix of
        # int and np.int64 values.
        for label_id, count in zip(ids.tolist(), counts.tolist()):
            class_frequencies[label_id] += count

    # Clamp empty classes to 1 to avoid division by zero when computing weights.
    return [count if count > 0 else 1 for count in class_frequencies.values()]

In [6]:
# CLASS_FREQUENCIES = calculate_frequencies(ds)

In [7]:
# CLASS_FREQUENCIES

In [8]:
# lovasz_softmax(probas, labels) --- https://github.com/bermanmaxim/LovaszSoftmax/tree/master
# boundary_loss --- https://github.com/yiskw713/boundary_loss_for_remote_sensing

In [9]:
# Function to calculate class weights
# w_c = (f_t / f_c)^i, where f_c is the frequency of class c, f_t is the median of
# all class frequencies, and i is the exponent.
def calculate_class_weights(frequencies, exponent):
    """Median-frequency class weights ``(median(f) / f_c) ** exponent``.

    Args:
        frequencies: Sequence (list, tuple, array) of per-class counts; every
            entry must be strictly positive.
        exponent: Power applied to the median/frequency ratio; 1 gives classic
            median-frequency balancing, smaller values soften the weighting.

    Returns:
        1-D ``torch.float32`` tensor with one weight per entry of ``frequencies``.

    Raises:
        ValueError: if ``frequencies`` is empty or contains a non-positive value,
            which would otherwise yield NaN/inf weights and poison the loss.
    """
    freqs = np.asarray(frequencies, dtype=np.float64)
    if freqs.size == 0:
        raise ValueError("frequencies must be non-empty")
    if np.any(freqs <= 0):
        raise ValueError("all class frequencies must be positive; zero counts give infinite weights")
    median_freq = np.median(freqs)
    class_weights = (median_freq / freqs) ** exponent
    return torch.tensor(class_weights, dtype=torch.float32)

# Softening exponent i in the weight formula above.
EXPONENT_I = 0.5

In [10]:
# NOTE(review): calculate_frequencies counts labels -1..19 by default, i.e. 21
# entries with label -1 at index 0. CrossEntropyLoss needs exactly one weight per
# logit channel (20 in the random-logit loss example), so drop the label -1 entry
# (index 0) before building the loss, or the constructor/forward will fail.
# # Calculate class weights
# class_weights = calculate_class_weights(CLASS_FREQUENCIES, EXPONENT_I)
# # Weighted cross-entropy loss
# loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)

In [11]:
#| hide
# Smoke test: run TransRVNet_loss on random logits and targets to check the
# expected input shapes. Kept as a usage example; consider dropping it if it
# adds no value.
# Seeded so the displayed value is reproducible across kernel restarts, and
# uses its own names so it does not clobber `mask` from the data cells.
generator = torch.Generator().manual_seed(0)
# NOTE(review): assumes 20 classes and a 64 x 1024 range image.
example_logits = torch.randn(1, 20, 64, 1024, generator=generator)                  # (B, C, H, W)
example_target = torch.randint(high=20, size=(1, 64, 1024), generator=generator)    # (B, H, W)

print(example_logits.shape)
print(example_target.shape)

loss = TransRVNet_loss()
example_loss = loss(example_logits, example_target)
example_loss

torch.Size([1, 20, 64, 1024])
torch.Size([1, 64, 1024])


tensor(6.9845, grad_fn=<AddBackward0>)

### Convolution Parameters

In [12]:
# Channel configuration for the two TransVRNet input branches. Both configs are
# identical except for `b1_in` (presumably the input channel count of the first
# block -- confirm against TransVRNet): 1 for p1, 3 for p2.
p1 = dict(
    b1_in=1,
    b1_out1=1, b1_out2=1, b1_out3=1,
    b2_in=32, b2_out=32,
    b3_in=64, b3_out1=64, b3_out2=64, b3_out3=64,
    output=128,
)

# Same layout as p1; only the first block's input width changes.
p2 = {**p1, "b1_in": 3}

## Training

In [13]:
# Prefer CUDA, then Apple's MPS backend, and fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [None]:
# Build the network from the two branch configs, then move it to `device`.
model = TransVRNet(p1, p2)
model = model.to(device)
print(model)

In [15]:
# Training hyper-parameters.
EPOCHS = 30
MAX_LR = 0.002

# Augmentation / regularization modules used by the training loop.
# NOTE(review): RandomInvert is an image transform; the training loop applies it to
# depth and xyz tensors too -- confirm that is intended.
dropout = torch.nn.Dropout(p=0.2, inplace=False)
inverter = torchvision.transforms.RandomInvert(p=0.5)
rotation = torchvision.transforms.RandomRotation(degrees=(0, 180))

# OneCycleLR overwrites each param group's lr when it is constructed
# (initial lr = max_lr / div_factor), so the optimizer's own lr is never used.
# Pass MAX_LR here rather than an unrelated value (previously 1e-1).
optimizer = torch.optim.AdamW(model.parameters(), lr=MAX_LR, weight_decay=0.0001)
# div_factor=1: start directly at MAX_LR (no warm-up ramp);
# final_div_factor=10: anneal down to MAX_LR / 10 by the last step.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=MAX_LR,
    div_factor=1,
    final_div_factor=10,
    steps_per_epoch=len(data_loader),
    epochs=EPOCHS,
)

In [16]:
# Peek at a single batch and split the projected image into the model's three inputs.
# NOTE(review): assumes `img` is channels-last (B, H, W, C) with channel 0 used as
# `ref`, 1 as `dep` and 2:5 as x/y/z -- confirm against ProjectionTransform.
img, label, mask = next(iter(data_loader))
ref = img[:, :, :, 0].unsqueeze(1)   # (B, 1, H, W)
dep = img[:, :, :, 1].unsqueeze(1)   # (B, 1, H, W)
# Channels-last -> channels-first with the spatial axes kept in order.
# permute(0, 3, 2, 1) would also swap H and W, so xyz would not line up with ref/dep.
xyz = img[:, :, :, 2:5].permute(0, 3, 1, 2)   # (B, 3, H, W)
assert ref.shape[-2:] == dep.shape[-2:] == xyz.shape[-2:], "model inputs must share spatial dims"
print(mask.shape)

torch.Size([1, 64, 1024])


In [17]:
# NOTE(review): issues to resolve before re-enabling this loop:
# - `total_loss`, `batch_idx` and `log_interval` are not defined in the visible cells;
#   iterate with `enumerate(data_loader)` and use an instance of `TransRVNet_loss`.
# - `xyz.permute(0, 3, 2, 1)` swaps the two spatial axes relative to `ref`/`dep`;
#   channels-last -> channels-first without transposing is `permute(0, 3, 1, 2)`.
# - `epochs = 1` here while `scheduler` was built for 30 epochs, so only the first
#   1/30th of the one-cycle schedule would run.
# - `inverter` is called separately on ref, dep and xyz, so each input gets an
#   independent random inversion decision within the same sample.
# # epochs = 30
# epochs = 1
# for epoch in range(epochs):
#     model.train()
    
#     for batch in data_loader:
#         img, label, mask = batch
#         ref = img[:,:,:,0].unsqueeze(1)
#         dep = img[:,:,:,1].unsqueeze(1)
#         xyz = img[:,:,:,2:5]
#         xyz = xyz.permute(0, 3, 2, 1)
        
#         optimizer.zero_grad()
        
#         # Data augmentation
#         ref = inverter(ref)
#         dep = inverter(dep)
#         xyz = inverter(xyz)

#         # rotating the image is a problem: the mask would have to be rotated as well
#         # data = rotation(data)
        
#         # Forward pass
#         output = model(ref, dep, xyz)
        
#         # Apply dropout
#         output = dropout(output)
        
#         loss = total_loss(output, mask)
        
#         loss.backward()
        
#         optimizer.step()
        
#         scheduler.step()
        
#         # if batch_idx % log_interval == 0:
#             # print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item()}")
    
#         print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item()}")

# torch.save(model.state_dict(), "trained_transvrnet_model.pth")