In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from colorcloud.datatools import SemanticKITTIDataset, SphericalProjectionTransform, ProjectionToTensorTransform
from colorcloud.models import RIUNet
from torchvision.transforms import v2
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
import lightning as L

In [None]:
class LitModel(L.LightningModule):
    """LightningModule wrapping ``RIUNet`` for per-pixel segmentation of projected scans.

    Training minimises a cross-entropy loss averaged only over the pixels
    selected by the batch's ``mask``, so invalid projection pixels do not
    contribute to the gradient.

    Args:
        lr: learning rate for the Adam optimizer (default ``1e-3``).
    """

    def __init__(self, lr: float = 1e-3):
        super().__init__()
        # Stores ``lr`` in ``self.hparams`` (and in checkpoints) instead of a
        # magic number buried in ``configure_optimizers``.
        self.save_hyperparameters()
        self.net = RIUNet()
        # Unreduced per-element loss so the masked mean can be taken in training_step.
        self.loss_fn = CrossEntropyLoss(reduction='none')

    def training_step(self, batch, batch_idx):
        """Compute the masked cross-entropy loss for one batch.

        Args:
            batch: ``(img, label, mask)`` tuple. ``mask`` selects the pixels
                that count toward the loss (NOTE(review): assumed to be a bool
                tensor with the same shape as ``label`` -- confirm against
                ``ProjectionToTensorTransform``).
            batch_idx: index of the batch within the epoch (unused).

        Returns:
            Scalar loss tensor: the mean per-pixel loss over masked-in pixels,
            or 0 (still attached to the autograd graph) when no pixel is valid.
        """
        img, label, mask = batch
        # Give invalid pixels a legal class index so the loss can be evaluated
        # everywhere; those entries are discarded below. Out-of-place, so the
        # batch tensor produced by the DataLoader is not mutated.
        label = label.masked_fill(~mask, 0)

        # Compute prediction error
        pred = self.net(img)
        loss = self.loss_fn(pred, label)

        # Masked mean. Unlike ``loss[mask].mean()``, this yields 0 rather than
        # NaN when the mask is empty, so a batch with no valid pixels cannot
        # poison the optimizer state and weights.
        n_valid = mask.sum().clamp(min=1)
        loss = loss.masked_fill(~mask, 0).sum() / n_valid

        # batch_size passed explicitly: the batch is a tuple, so Lightning
        # would otherwise have to guess it.
        self.log('train_loss', loss, prog_bar=True, batch_size=img.shape[0])
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``hparams.lr``."""
        return Adam(self.parameters(), lr=self.hparams.lr)

In [None]:
# Root of the SemanticKITTI data, relative to this notebook.
data_path = '../../data'
ds = SemanticKITTIDataset(data_path)

# Spherical projection of each scan into a 2D image: field-of-view bounds in
# degrees (up / down) and the output image width W and height H in pixels.
tfms = v2.Compose([
    SphericalProjectionTransform(fov_up_deg=12., fov_down_deg=-26., W=1024, H=64),
    ProjectionToTensorTransform(),
])
ds.set_transform(tfms)

bs = 10
# shuffle=True: with the default (False), every epoch visits scans in the same
# fixed index order. If the dataset indexes scans in sequence order
# (presumably; verify), consecutive batches would be near-duplicate frames.
dl = DataLoader(ds, batch_size=bs, shuffle=True, num_workers=4)

In [None]:
# train model
# Seed Python/NumPy/torch RNGs (and DataLoader workers) so weight
# initialisation and shuffle order are reproducible on Restart & Run All.
SEED = 42
L.seed_everything(SEED, workers=True)

# Keep a handle on the model so it can be inspected after training.
model = LitModel()
trainer = L.Trainer(max_epochs=1)
trainer.fit(model=model, train_dataloaders=dl)