In [None]:
# Reload edited modules (e.g. colorcloud) automatically before each cell runs,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from colorcloud.datatools import SemanticKITTIDataset, SphericalProjectionTransform, ProjectionToTensorTransform
from colorcloud.models import RIUNet
from torchvision.transforms import v2
from torch.utils.data import DataLoader
import torch

In [None]:
# Dataset root, relative to this notebook's working directory.
data_path = '../../Cloud2DImageConverter/point_clouds/semantic_kitti/'
ds = SemanticKITTIDataset(data_path)

# Spherical projection onto an H=64 x W=1024 grid spanning +4 deg (up) to
# -26 deg (down) vertically, followed by conversion to tensors
# (presumably the (img, label, mask) triple consumed by the training loop).
projection = SphericalProjectionTransform(fov_up_deg=4., fov_down_deg=-26., W=1024, H=64)
tfms = v2.Compose([projection, ProjectionToTensorTransform()])
ds.set_transform(tfms)

# NOTE: DataLoader's default shuffle=False, so batches arrive in dataset order every epoch.
bs = 3
dl = DataLoader(ds, batch_size=bs)

In [None]:
# Prefer a CUDA GPU, then Apple-silicon MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [None]:
# Seed PyTorch's RNG before building the network: the Conv2d weights are
# randomly initialised, so without a seed the starting model (and the loss
# curve it produces) differs on every kernel restart. Some GPU kernels may
# still be nondeterministic.
SEED = 0
torch.manual_seed(SEED)
model = RIUNet().to(device)
print(model)

RIUNet(
  (backbone): Sequential(
    (0): Encoder(
      (blocks): ModuleList(
        (0): Block(
          (net): Sequential(
            (0): Conv2d(5, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, padding_mode=circular)
            (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU()
            (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, padding_mode=circular)
            (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (5): ReLU()
          )
        )
        (1): Block(
          (net): Sequential(
            (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, padding_mode=circular)
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU()
            (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1

In [None]:
# reduction='none' keeps one loss value per pixel, so the training loop can
# average only over the pixels selected by `mask`.
loss_fn = torch.nn.CrossEntropyLoss(reduction='none')

LEARNING_RATE = 1e-1
optimizer = torch.optim.SGD(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
n_epochs = 50
log_every = 5  # print the mean epoch loss every `log_every` epochs, plus the final epoch


def masked_mean(per_pixel_loss, mask):
    """Average ``per_pixel_loss`` over the positions where ``mask`` is True.

    Unlike ``per_pixel_loss[mask].mean()``, an all-False mask yields 0
    instead of NaN, so a batch with no valid pixels neither produces a NaN
    loss nor pollutes the logged epoch average.

    Args:
        per_pixel_loss: unreduced loss, same shape as ``mask``.
        mask: boolean tensor selecting valid positions (assumed bool — TODO confirm).

    Returns:
        Scalar tensor: sum of the selected losses / max(#selected, 1).
    """
    n_valid = mask.sum().clamp(min=1)
    return per_pixel_loss[mask].sum() / n_valid


# Re-running this cell continues training from the current weights; the
# model is not re-initialised here.
model.train()
for epoch in range(n_epochs):
    # Accumulate on-device so there is no host sync per batch; synced only when logging.
    epoch_loss = torch.zeros((), device=device)
    n_batches = 0
    for img, label, mask in dl:
        img, label, mask = img.to(device), label.to(device), mask.to(device)
        # Give invalid pixels a placeholder class 0 so the dense loss can be
        # computed (their original label values are presumably out of range —
        # verify); they are excluded from the average by `masked_mean`.
        # masked_fill returns a new tensor instead of mutating the batch.
        label = label.masked_fill(~mask, 0)

        # Compute prediction error over valid pixels only
        pred = model(img)
        loss = masked_mean(loss_fn(pred, label), mask)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.detach()
        n_batches += 1

    if epoch % log_every == 0 or epoch == n_epochs - 1:
        mean_loss = (epoch_loss / max(n_batches, 1)).item()
        print(f"epoch {epoch:>3d} | loss: {mean_loss:>7f}")

loss: 3.037323
loss: 2.397722
loss: 0.999617
loss: 1.008595
loss: 0.775321
loss: 0.780033
loss: 0.621573
loss: 0.660549
loss: 0.601765
loss: 0.569387
loss: 0.455978
loss: 0.480175
loss: 0.416153
loss: 0.433895
loss: 0.397837
loss: 0.399211
loss: 0.350461
loss: 0.375299
loss: 0.376157
loss: 0.376882
