In [1]:
# Re-import edited modules (e.g. the local colorcloud package) before each cell executes,
# so changes to the library are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from colorcloud.cheng2023TransRVNet import TransVRNet
from colorcloud.cheng2023TransRVNet import TransRVNet_loss
from colorcloud.cheng2023TransRVNet import RandomRotationTransform, RandomDroppingPointsTransform, RandomSingInvertingTransform
import lightning as L
import numpy as np
import wandb
from lightning.pytorch.loggers import WandbLogger
from datetime import datetime
import torch
import torchvision
from lovasz_losses import lovasz_softmax
from boundary_loss import BoundaryLoss
from colorcloud.behley2019iccv import SemanticKITTIDataset, SphericalProjection, ProjectionTransform, ProjectionToTensorTransform
from torchvision.transforms import v2
from torch.utils.data import DataLoader

## Data

In [5]:
# Relative path handed to SemanticKITTIDataset; resolved against the kernel's working directory.
data_path = '../../data'
ds = SemanticKITTIDataset(data_path)

# Peek at a single sample (index 128), fetched before set_transform() is called.
frame, label, mask = ds[128]

In [6]:
# Range-image projection: H=64 rows by W=1024 columns, vertical field of view from -26 to +12 degrees.
proj = SphericalProjection(H=64, W=1024, fov_down_deg=-26., fov_up_deg=12.)

# Random point-cloud augmentations run first, then projection to an image and tensor conversion.
augment_then_project = [
    RandomDroppingPointsTransform(),
    RandomRotationTransform(),
    RandomSingInvertingTransform(),
    ProjectionTransform(proj),
    ProjectionToTensorTransform(),
]
tfms = v2.Compose(augment_then_project)
ds.set_transform(tfms)

# Fetch the same index again, now after set_transform(tfms).
frame_img, label_img, mask_img = ds[128]

In [7]:
bs = 1

# This loader feeds the training loop: reshuffle every epoch, otherwise each epoch
# visits the samples in exactly the same (dataset) order.
data_loader = DataLoader(ds, batch_size=bs, shuffle=True)

## Setup

### Convolution parameters

`p1` and `p2` are passed to `TransVRNet`; they are identical except for `b1_in` (1 vs. 3).

In [8]:
def conv_params(b1_in: int) -> dict:
    """Build one convolution hyper-parameter dict for TransVRNet.

    The two parameter sets used below are identical except for ``b1_in``
    (presumably the input-channel count of block 1 — confirm in colorcloud).

    Args:
        b1_in: value stored under the ``"b1_in"`` key.

    Returns:
        A fresh dict; callers may mutate it without affecting other sets.
    """
    return {
        "b1_in": b1_in,
        "b1_out1": 1,
        "b1_out2": 1,
        "b1_out3": 1,
        "b2_in": 32,
        "b2_out": 32,
        "b3_in": 64,
        "b3_out1": 64,
        "b3_out2": 64,
        "b3_out3": 64,
        "output": 128,
    }


p1 = conv_params(b1_in=1)
p2 = conv_params(b1_in=3)

## Training

In [None]:
# Choose the compute backend: CUDA first, then Apple MPS, otherwise CPU.
# (To force CPU for debugging, assign device = "cpu" after this block.)
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# Seed torch's global RNG (all devices) before building the model, so torch-driven
# randomness such as weight initialisation and dropout masks is reproducible
# under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

model = TransVRNet(p1, p2).to(device)
# The loss is constructed with the target device — presumably it allocates
# internal tensors there; TODO confirm in colorcloud.cheng2023TransRVNet.
loss_fn = TransRVNet_loss(device)

In [11]:
# Standalone dropout applied to the network output in the training loop.
dropout = torch.nn.Dropout(p=0.2, inplace=False)

# One-cycle LR schedule settings.
# OneCycleLR overwrites the optimizer's lr with MAX_LR / DIV_FACTOR when it is
# constructed, so the optimizer is given that same value instead of an unused
# placeholder (the previous lr=1e-1 never took effect).
MAX_LR = 0.002
DIV_FACTOR = 1
FINAL_DIV_FACTOR = 10
# Length of the schedule in epochs. OneCycleLR raises a ValueError once it is
# stepped more than SCHEDULER_EPOCHS * len(data_loader) times, so training must
# not run for more epochs than this.
SCHEDULER_EPOCHS = 30

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=MAX_LR / DIV_FACTOR,
    weight_decay=0.0001,
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=MAX_LR,
    div_factor=DIV_FACTOR,
    final_div_factor=FINAL_DIV_FACTOR,
    steps_per_epoch=len(data_loader),
    epochs=SCHEDULER_EPOCHS,
)

In [12]:
# Sanity-check one batch: split the projected image into the model's three inputs
# (channels 0-2 -> xyz, 3 -> reflectance, 4 -> depth; the single-channel inputs
# keep a channel dimension via unsqueeze(1)).
img, label, mask = next(iter(data_loader))
xyz = img[:, :3, :, :]
reflectance = img[:, 3, :, :].unsqueeze(1)
depth = img[:, 4, :, :].unsqueeze(1)

# Display the resulting shapes as the cell output.
xyz.shape, reflectance.shape, depth.shape

In [None]:
epochs = 1
# Print a progress line only every `log_every` batches; printing every batch
# floods the notebook output when the loader yields thousands of batches.
log_every = 100

# OneCycleLR raises once stepped more than `total_steps` times; check up front so a
# mismatch between `epochs` and the scheduler's configured length fails before training.
if epochs * len(data_loader) > scheduler.total_steps:
    raise ValueError(
        f"Training for {epochs} epochs needs {epochs * len(data_loader)} scheduler steps, "
        f"but the OneCycleLR schedule only has {scheduler.total_steps}."
    )

for epoch in range(epochs):
    model.train()
    # `dropout` is a separate module, so model.train() does not switch its mode.
    dropout.train()
    running_loss = 0.0

    for batch_idx, (img, label, mask) in enumerate(data_loader):
        img, label, mask = img.to(device), label.to(device), mask.to(device)

        # Separate channels: 0-2 -> xyz, 3 -> reflectance, 4 -> depth
        xyz = img[:, :3, :, :]
        reflectance = img[:, 3, :, :].unsqueeze(1)
        depth = img[:, 4, :, :].unsqueeze(1)

        optimizer.zero_grad()

        # Forward pass, then dropout on the network output.
        # NOTE(review): dropout on final predictions is unusual — confirm intended.
        pred = dropout(model(reflectance, depth, xyz))

        # Overwrite labels at masked-out pixels with 0 (presumably so the loss only
        # sees valid class indices); those pixels are excluded from the mean below.
        label[~mask] = 0
        # Assumes loss_fn returns a per-pixel loss indexable by `mask` — TODO confirm.
        loss = loss_fn(pred, label)[mask].mean()

        # Backpropagation and parameter / LR updates
        loss.backward()
        optimizer.step()
        scheduler.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        if batch_idx % log_every == 0:
            print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {batch_loss}")

    # max(..., 1) avoids division by zero for an empty loader.
    print(f"Epoch {epoch} done, mean loss: {running_loss / max(len(data_loader), 1)}")

torch.save(model.state_dict(), "trained_transvrnet_model.pth")