In [1]:
# Automatically reload edited Python modules (e.g. the local `colorcloud` package)
# before each cell executes, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
from colorcloud.cheng2023TransRVNet import TransVRNet, SemanticSegmentationTask
from colorcloud.cheng2023TransRVNet import TransRVNet_loss
from colorcloud.cheng2023TransRVNet import RandomRotationTransform, RandomDroppingPointsTransform, RandomSingInvertingTransform
import lightning as L
import numpy as np
import wandb
from lightning.pytorch.loggers import WandbLogger
from datetime import datetime
import torch
import torchvision
import lightning as L
from colorcloud.behley2019iccv import SemanticKITTIDataset, SphericalProjection, ProjectionTransform, ProjectionToTensorTransform, SemanticSegmentationLDM
from torchvision.transforms import v2
from torch.utils.data import DataLoader

## Setup

In [3]:
# Pick the compute backend: CUDA GPU if available, else Apple Metal (MPS), else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

# Seed Python, NumPy and torch RNGs so weight initialisation and the random
# point-cloud augmentations are reproducible. workers=True makes Lightning give
# each dataloader worker a derived, deterministic seed (num_workers > 0 below).
SEED = 42
L.seed_everything(SEED, workers=True)

Using cuda device


### Convolution Parameters

In [5]:
# Channel configuration for the two input branches ("p1", "p2") plus "output_conv".
# The branches differ only in `b1_in` (1 vs 3), so both are generated from a single
# template. Each branch is a fresh dict, and keys keep the order listed here.
mrciam_p = {
    **{
        branch_name: {
            "b1_in": first_in_channels,
            "b1_out1": 1,
            "b1_out2": 1,
            "b1_out3": 1,
            "b2_in": 32,
            "b2_out": 32,
            "b3_in": 64,
            "b3_out1": 64,
            "b3_out2": 64,
            "b3_out3": 64,
            "output": 128,
        }
        for branch_name, first_in_channels in (("p1", 1), ("p2", 3))
    },
    "output_conv": 384,
}

# Channel sizes for the two encoder modules. All inner convolutions use 64 channels;
# module_1 takes 384 input channels (the same value as mrciam_p["output_conv"]),
# while module_2 takes 64 and additionally declares `out_channels`.
encoder_p = dict(
    module_1=dict(
        in_channels=384,
        conv2_in_channels=64,
        conv2_out_channels=64,
        dilated_conv_out_channels=64,
        residual_out_channels=64,
    ),
    module_2=dict(
        in_channels=64,
        out_channels=64,
        conv2_in_channels=64,
        conv2_out_channels=64,
        dilated_conv_out_channels=64,
        residual_out_channels=64,
    ),
)

## Lightning Training

In [6]:
# Spherical range-image projection: 64 x 1024 image, vertical field of view
# from -26 deg to +12 deg.
proj = SphericalProjection(W=1024, H=64, fov_up_deg=12.0, fov_down_deg=-26.0)

# Transform pipeline: three random point-cloud augmentations, then projection
# to the range image, then conversion to tensors.
# NOTE(review): `tfms` is not referenced by any later cell shown here; the data
# module below only receives `proj_kargs`, so these augmentations appear NOT to
# be applied during training. TODO confirm and pass `tfms` to the data module if
# SemanticSegmentationLDM supports it. W/H are also duplicated in `proj_kargs` below.
tfms = v2.Compose(
    [
        RandomDroppingPointsTransform(),
        RandomRotationTransform(),
        RandomSingInvertingTransform(),
        ProjectionTransform(proj),
        ProjectionToTensorTransform(),
    ]
)

In [None]:
# Build the network from the channel configs above and place it on the chosen device.
model = TransVRNet(
    mrciam_p,
    encoder_p,
).to(device)
# The loss receives the device explicitly (presumably so any internal tensors live there).
loss_fn = TransRVNet_loss(device)

In [None]:
# Data module for a 64 x 1024 projection; batch size 1 for training and evaluation,
# with 2 loader worker processes.
data = SemanticSegmentationLDM(
    proj_kargs=dict(W=1024, H=64),
    train_batch_size=1,
    eval_batch_size=1,
    num_workers=2,
)
# Set up early only to count training batches; note that trainer.fit() will call
# setup('fit') again on this data module.
data.setup('fit')
# Batches per epoch, used below to derive `total_steps`. Assumes one device and no
# gradient accumulation — TODO revisit if either changes.
epoch_steps = len(data.train_dataloader())

In [9]:
# Wrap model, loss and the data module's visualisation transform in the Lightning
# task; it is told the total step count up front (batches per epoch x epochs).
n_epochs = 1
learner = SemanticSegmentationTask(
    model,
    loss_fn,
    data.viz_tfm,
    total_steps=epoch_steps * n_epochs,
)

In [None]:
# Run name "<model>_<timestamp>". The timestamp is ISO-8601-like so run names sort
# chronologically, and it contains no '/' or ':' (path separators / characters
# invalid in Windows file names), so the name is safe if it is ever used in a path.
model_name = "TransVRNet"
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
experiment_name = f"{model_name}_{timestamp}"

# log_model="all" uploads every checkpoint Lightning saves as a W&B artifact;
# watch(..., log="all") records both gradient and parameter histograms.
wandb_logger = WandbLogger(project="colorcloud", name=experiment_name, log_model="all")
wandb_logger.watch(learner.model, log="all")

In [None]:
# Train the model, logging to Weights & Biases.
trainer = L.Trainer(max_epochs=n_epochs, logger=wandb_logger)
try:
    trainer.fit(learner, data)
finally:
    # Close the W&B run even if training fails or is interrupted. In a notebook the
    # run otherwise stays active for the life of the kernel, and a WandbLogger
    # created on a re-run would reuse it (ignoring the new run name and mixing logs).
    wandb.finish()