In [None]:
# Enable IPython's autoreload extension in mode 2: imported modules are reloaded
# before each cell executes, so edits to local source files take effect without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
from datetime import datetime

import lightning as L
import torch
import torchvision
# import wandb
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader
from torchvision.transforms import v2

from colorcloud.behley2019iccv import SemanticKITTIDataset, SphericalProjection, ProjectionToTensorTransform
from colorcloud.cheng2023TransRVNet import TransVRNet
from lovasz_losses import lovasz_softmax

## Data

In [None]:
# Location of the SemanticKITTI point clouds, relative to this notebook.
data_path = '../../Cloud2DImageConverter/point_clouds/semantic_kitti/'
ds = SemanticKITTIDataset(data_path)

# Per-sample transform pipeline: projection of the scan followed by tensor conversion.
# NOTE(review): the import cell brings in `SphericalProjection`, not
# `SphericalProjectionTransform` -- confirm which name colorcloud exports and
# align the import and this call.
tfms = v2.Compose([
    SphericalProjectionTransform(fov_up_deg=4., fov_down_deg=-26., W=1024, H=64),
    ProjectionToTensorTransform(),
])
ds.set_transform(tfms)

# NOTE(review): the Setup notes list a batch size of 6 -- confirm 3 is intentional.
bs = 3
# This loader feeds the training loop, so reshuffle the samples every epoch.
dl = DataLoader(ds, batch_size=bs, shuffle=True)

## Setup

- **Data Input:**
  - Utilizes channels for reflectance, depth, and XYZ coordinates.

- **Loss Functions:**
  - Three loss functions employed:
    - Weighted cross-entropy loss
    - Lovász-Softmax loss
    - Boundary loss

- **Optimizer:**
  - AdamW optimizer.

- **Weight Decay:**
  - Weight decay of 0.0001.

- **Epochs:**
  - 30 epochs.

- **Batch Size:**
  - Batch size of 6.

- **Learning Rate Policy:**
  - One-cycle learning rate policy.
  - The learning rate starts at its maximum of 0.002 and decays to 0.0002 by the end of the 30 epochs.

- **Dropout Probability:**
  - Probability of dropout: 0.2.

- **Data Augmentation:**
  - Random rotations, random point dropout, and random sign inversion for X and Y values applied during training with a probability of 0.5.
  - These augmentations increase input point cloud diversity and prevent overfitting.

## Training

In [None]:
# Constructor parameters for the first of the two parameter sets handed to TransVRNet.
# NOTE(review): the meaning of each key is defined by TransVRNet -- presumably
# per-block input/output channel counts; confirm against the model source.
p1 = dict(
    b1_in=1,
    b1_out1=1,
    b1_out2=1,
    b1_out3=1,
    b2_in=32,
    b2_out=32,
    b3_in=64,
    b3_out1=64,
    b3_out2=64,
    b3_out3=64,
    output=128,
)

# Second parameter set: identical to p1 except for "b1_in" (3 instead of 1).
# Unpacking p1 first keeps the original key order.
p2 = {**p1, "b1_in": 3}

In [None]:
# Choose the compute backend in order of preference: CUDA, then MPS, then CPU.
# The MPS check only runs when CUDA is unavailable.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# Instantiate the network from the two parameter sets, then place it on the
# selected device and print its structure.
model = TransVRNet(p1, p2)
model = model.to(device)
print(model)

In [None]:
# Per-class weights for the weighted cross-entropy term.
# TODO: replace `None` with a 1-D float tensor holding one weight per class, placed
# on `device`; with `None` the loss is an unweighted cross-entropy.
class_weights = None
loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)
# lovasz_softmax(probas, labels): just call the function directly; its authors say
# to apply a softmax to the logits first.
# https://github.com/bermanmaxim/LovaszSoftmax/tree/master

dropout = torch.nn.Dropout(p=0.2, inplace=False)
# NOTE(review): RandomInvert is an image colour inversion and RandomRotation rotates
# the 2-D image plane. Neither matches the "sign inversion for X and Y values" or the
# point-cloud rotation described in the Setup notes -- confirm before relying on them.
inverter = torchvision.transforms.RandomInvert(p=0.5)
rotation = torchvision.transforms.RandomRotation(degrees=(0, 180))

max_lr = 0.002
# OneCycleLR overwrites the optimizer's learning rate when it is constructed, so `lr`
# is only a starting placeholder; keep it equal to the schedule's peak for clarity.
optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, weight_decay=0.0001)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=max_lr,
    div_factor=1,            # initial lr = max_lr / 1, i.e. training starts at the peak
    final_div_factor=10,     # final lr = initial lr / 10 = 0.0002
    steps_per_epoch=len(dl),  # `dl` is the DataLoader built in the Data section
    epochs=30,
)

In [None]:
def total_loss(output, target):
    """Combine the segmentation loss terms for one batch.

    Args:
        output: Raw class logits from the model, with the class dimension at
            index 1 (the layout `torch.nn.CrossEntropyLoss` expects).
        target: Integer class labels shaped like `output` without the class
            dimension.

    Returns:
        Scalar tensor: cross-entropy (via `loss_fn`) plus Lovasz-Softmax.
    """
    weighted_cross_entropy_loss = loss_fn(output, target)
    # Lovasz-Softmax works on class probabilities, so normalise the logits first.
    probas = torch.softmax(output, dim=1)
    lovasz_loss = lovasz_softmax(probas, target)
    # TODO: add the boundary loss term once an implementation is available.
    # NOTE(review): the terms are summed with equal weight -- confirm the
    # intended weighting of the combination.
    return weighted_cross_entropy_loss + lovasz_loss

In [None]:
epochs = 30        # must match the `epochs` passed to the OneCycleLR scheduler
log_interval = 10  # print the loss every `log_interval` batches

for epoch in range(epochs):
    model.train()

    # NOTE(review): assumes every batch unpacks into (input, labels) -- confirm
    # against what the dataset transform returns.
    for batch_idx, (data, target) in enumerate(dl):
        # The model was moved to `device`, so the batch has to live there as well.
        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()

        # No augmentation or dropout is applied at this point on purpose:
        # - RandomRotation on the input alone misaligns it with the un-rotated labels;
        # - RandomInvert is a colour inversion, not a sign flip of X/Y;
        # - Dropout on the output logits corrupts the loss instead of regularising.
        # Point-cloud augmentations belong in the dataset transform (before the
        # projection) and dropout belongs inside the model.
        output = model(data)

        loss = total_loss(output, target)
        loss.backward()

        optimizer.step()
        # OneCycleLR is stepped once per batch, not once per epoch.
        scheduler.step()

        if batch_idx % log_interval == 0:
            print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item()}")

torch.save(model.state_dict(), "trained_transvrnet_model.pth")