In [70]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [71]:
from colorcloud.cheng2023TransRVNet import TransVRNet, SemanticSegmentationTask
from colorcloud.cheng2023TransRVNet import TransRVNet_loss
from colorcloud.cheng2023TransRVNet import RandomRotationTransform, RandomDroppingPointsTransform, RandomSingInvertingTransform
import lightning as L
import numpy as np
import wandb
from lightning.pytorch.loggers import WandbLogger
from datetime import datetime
import torch
import torchvision
import lightning as L
from colorcloud.behley2019iccv import SemanticKITTIDataset, SphericalProjection, ProjectionTransform, ProjectionToTensorTransform, SemanticSegmentationLDM
from torchvision.transforms import v2
from torch.utils.data import DataLoader

## Setup

In [72]:
# Select the compute backend by availability: CUDA first, then MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


### Convolutions' Parameters

In [73]:
# Channel configuration passed to TransVRNet as its first argument.
# The two branches "p1" and "p2" share every value except "b1_in"
# (1 for "p1", 3 for "p2"), so both are generated from one template.
# Each branch gets its own dict object; "output_conv" stays at the top level.
mrciam_p = {
    **{
        branch: {
            "b1_in": b1_in,
            "b1_out1": 32,
            "b1_out2": 32,
            "b1_out3": 32,
            "b2_in": 64,
            "b2_out": 64,
            "b3_in": 64,
            "b3_out1": 64,
            "b3_out2": 64,
            "b3_out3": 64,
            "output": 64,
        }
        for branch, b1_in in (("p1", 1), ("p2", 3))
    },
    "output_conv": 192,
}
# Channel configuration passed to TransVRNet as its second argument.
# "module_1" takes 192 input channels, matching mrciam_p["output_conv"];
# "module_2" takes 256, matching module_1's "residual_out_channels".
encoder_p = dict(
    module_1=dict(
        in_channels=192,
        conv2_in_channels=128,
        conv2_out_channels=256,
        dilated_conv_out_channels=256,
        residual_out_channels=256,
    ),
    module_2=dict(
        in_channels=256,
        conv2_in_channels=256,
        conv2_out_channels=256,
        dilated_conv_out_channels=256,
        residual_out_channels=256,
    ),
)

# Channel configuration passed to TransVRNet as its third argument.
# NOTE(review): how in_channels=264 is derived from the encoder output is not
# visible in this notebook; confirm against the model definition.
decoder_p = dict(
    in_channels=264,
    conv2_in_channels=128,
    dilated_conv_in_channels=128,
    dilated_conv_out_channels=64,
    output=32,
)

# Configuration passed to TransVRNet as its fourth argument.
# "embed_dim" is one eighth of the second encoder module's residual output
# channels, truncated to an int (256 / 8 -> 32 with the current encoder_p).
p_bntm = dict(
    window_size=(4, 4),
    embed_dim=int(encoder_p["module_2"]["residual_out_channels"] / 8),
)

## Lightning Training

In [74]:
# Augmentation pipeline handed to the data module via `aug_tfms`.
# Each transform is instantiated with its default arguments and applied in the
# listed order: point dropping, rotation, then sign inversion.
aug_tfms = v2.Compose(
    [
        make_transform()
        for make_transform in (
            RandomDroppingPointsTransform,
            RandomRotationTransform,
            RandomSingInvertingTransform,
        )
    ]
)

In [75]:
# Build the network from the four parameter dicts defined earlier, then move it
# to the selected device. The loss object is constructed with that same device.
model = TransVRNet(
    mrciam_p,
    encoder_p,
    decoder_p,
    p_bntm,
)
model = model.to(device)
loss_fn = TransRVNet_loss(device)

In [76]:
# Data module using the spherical projection style with a 64 x 1024 (H x W)
# image and the field-of-view bounds given in `proj_kargs`. Both the training
# and evaluation loaders use batch size 1, with 2 workers and the augmentation
# pipeline built earlier.
data = SemanticSegmentationLDM(
    proj_kargs=dict(fov_up_deg=4., fov_down_deg=-26., W=1024, H=64),
    proj_style='spherical',
    train_batch_size=1,
    eval_batch_size=1,
    num_workers=2,
    aug_tfms=aug_tfms,
)
# Set up the 'fit' stage so the training dataloader can be queried below.
data.setup('fit')
# Number of batches in one training epoch; used to size the learner's total_steps.
epoch_steps = len(data.train_dataloader())

In [77]:
# Number of passes over the training set; also passed to the Trainer as max_epochs.
n_epochs = 1

# Wrap the model, the loss and the data module's visualization transform into
# the training task. `total_steps` is the number of batches over the whole run
# (batches per epoch times number of epochs).
learner = SemanticSegmentationTask(
    model,
    loss_fn,
    data.viz_tfm,
    total_steps=epoch_steps * n_epochs,
)

In [78]:
# Name the run "<model>_<timestamp>", with the timestamp taken at cell execution.
# NOTE(review): the timestamp contains '/' and ':' characters; confirm this is
# acceptable wherever the run name is reused (e.g. as a path component).
model_name = "TransVRNet"
timestamp = f"{datetime.now():%d/%m/%Y_%H:%M:%S}"
experiment_name = "_".join((model_name, timestamp))

# Log to the "colorcloud" W&B project under that run name, with
# log_model="all", and register the learner's model with `watch(log="all")`.
wandb_logger = WandbLogger(
    project="colorcloud",
    name=experiment_name,
    log_model="all",
)
wandb_logger.watch(learner.model, log="all")

[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


In [79]:
# Run training for `n_epochs` epochs, reporting to the W&B logger.
trainer = L.Trainer(
    logger=wandb_logger,
    max_epochs=n_epochs,
)
# The data module is passed by keyword; it supplies the dataloaders for fitting.
trainer.fit(learner, datamodule=data)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type                   | Params | Mode 
------------------------------------------------------------------
0 | model          | TransVRNet             | 10.1 M | train
1 | loss_fn        | TransRVNet_loss        | 0      | train
2 | viz_tfm        | ProjectionVizTransform | 0      | train
3 | dropout        | Dropout                | 0      | train
4 | train_accuracy | MulticlassAccuracy     | 0      | train
5 | val_accuracy   | MulticlassAccuracy     | 0      | train
------------------------------------------------------------------
10.1 M    Trainable params
0         Non-trainable params
10.1 M    Total params
40.577    Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]torch.Size([20, 64, 1024]) tensor([[[[ 6.4618e-03,  1.4485e-02,  3.4077e-03,  ...,  2.1479e-02,
            7.2466e-03, -4.0908e-03],
          [ 9.2486e-04, -1.5524e-03,  2.9876e-03,  ...,  6.8359e-03,
           -4.1401e-05,  3.7444e-03],
          [-1.0838e-03,  8.5473e-04,  1.1149e-02,  ..., -4.2438e-03,
            1.9918e-02,  9.3318e-03],
          ...,
          [ 1.3178e-02,  1.4874e-02,  3.4720e-03,  ...,  2.0119e-02,
           -3.3018e-03,  2.9897e-03],
          [ 1.2037e-02, -2.9710e-03,  6.9147e-03,  ..., -3.6481e-03,
           -4.1545e-03,  6.4552e-04],
          [ 1.1956e-02, -3.8156e-03, -1.4196e-04,  ...,  1.8375e-02,
            7.2472e-03,  2.9572e-03]],

         [[ 1.1329e-01,  1.0386e-01,  1.2492e-01,  ...,  1.0693e-01,
            1.2534e-01,  1.1451e-01],
          [ 1.1529e-01,  1.2323e-01,  1.1141e-01,  ...,  1.0212e-01,
            1.0728e-01,  1.1256e-01],
          [ 1.2044e-01,  1.0743e-